This file is a merged representation of the entire codebase, combined into a single document by Repomix. This section contains a summary of this file. This file contains a packed representation of the entire repository's contents. It is designed to be easily consumable by AI systems for analysis, code review, or other automated processes. The content is organized as follows: 1. This summary section 2. Repository information 3. Directory structure 4. Repository files, each consisting of: - File path as an attribute - Full contents of the file - This file should be treated as read-only. Any changes should be made to the original repository files, not this packed version. - When processing this file, use the file path to distinguish between different files in the repository. - Be aware that this file may contain sensitive information. Handle it with the same level of security as you would the original repository. - Some files may have been excluded based on .gitignore rules and Repomix's configuration - Binary files are not included in this packed representation. 
Please refer to the Repository Structure section for a complete list of file paths, including binary files - Files matching patterns in .gitignore are excluded - Files matching default ignore patterns are excluded - Files are sorted by Git change count (files with more changes are at the bottom) base/ grouping/ __init__.py base.py nb.py resampling/ __init__.py base.py nb.py __init__.py accessors.py chunking.py combining.py decorators.py flex_indexing.py indexes.py indexing.py merging.py preparing.py reshaping.py wrapping.py data/ custom/ __init__.py alpaca.py av.py bento.py binance.py ccxt.py csv.py custom.py db.py duckdb.py feather.py file.py finpy.py gbm_ohlc.py gbm.py hdf.py local.py ndl.py parquet.py polygon.py random_ohlc.py random.py remote.py sql.py synthetic.py tv.py yf.py __init__.py base.py decorators.py nb.py saver.py updater.py generic/ nb/ __init__.py apply_reduce.py base.py iter_.py patterns.py records.py rolling.py sim_range.py splitting/ __init__.py base.py decorators.py nb.py purged.py sklearn_.py __init__.py accessors.py analyzable.py decorators.py drawdowns.py enums.py plots_builder.py plotting.py price_records.py ranges.py sim_range.py stats_builder.py indicators/ custom/ __init__.py adx.py atr.py bbands.py hurst.py ma.py macd.py msd.py obv.py ols.py patsim.py pivotinfo.py rsi.py sigdet.py stoch.py supertrend.py vwap.py __init__.py configs.py enums.py expr.py factory.py nb.py talib_.py labels/ generators/ __init__.py bolb.py fixlb.py fmax.py fmean.py fmin.py fstd.py meanlb.py pivotlb.py trendlb.py __init__.py enums.py nb.py ohlcv/ __init__.py accessors.py enums.py nb.py portfolio/ nb/ __init__.py analysis.py core.py ctx_helpers.py from_order_func.py from_orders.py from_signals.py iter_.py records.py pfopt/ __init__.py base.py nb.py records.py __init__.py base.py call_seq.py chunking.py decorators.py enums.py logs.py orders.py preparing.py trades.py px/ __init__.py accessors.py decorators.py records/ __init__.py base.py chunking.py col_mapper.py 
decorators.py mapped_array.py nb.py registries/ __init__.py ca_registry.py ch_registry.py jit_registry.py pbar_registry.py returns/ __init__.py accessors.py enums.py nb.py qs_adapter.py signals/ generators/ __init__.py ohlcstcx.py ohlcstx.py rand.py randnx.py randx.py rprob.py rprobcx.py rprobnx.py rprobx.py stcx.py stx.py __init__.py accessors.py enums.py factory.py nb.py templates/ dark.json light.json seaborn.json utils/ knowledge/ __init__.py asset_pipelines.py base_asset_funcs.py base_assets.py chatting.py custom_asset_funcs.py custom_assets.py formatting.py __init__.py annotations.py array_.py attr_.py base.py caching.py chaining.py checks.py chunking.py colors.py config.py cutting.py datetime_.py datetime_nb.py decorators.py enum_.py eval_.py execution.py figure.py formatting.py hashing.py image_.py jitting.py magic_decorators.py mapping.py math_.py merging.py module_.py params.py parsing.py path_.py pbar.py pickling.py profiling.py random_.py requests_.py schedule_.py search_.py selection.py tagging.py telegram.py template.py warnings_.py __init__.py _dtypes.py _opt_deps.py _settings.py _typing.py _version.py accessors.py This section contains the contents of the repository's files. # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# ===================================================================================

"""Modules with classes and utilities for grouping."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vectorbtpro.base.grouping.base import *
    from vectorbtpro.base.grouping.nb import *

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Base classes and functions for grouping.

Class `Grouper` stores metadata related to grouping index. It can return, for example,
the number of groups, the start indices of groups, and other information useful for reducing
operations that utilize grouping. It also allows to dynamically enable/disable/modify groups
and checks whether a certain operation is permitted."""

import numpy as np
import pandas as pd
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.resample import Resampler as PandasResampler

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes
from vectorbtpro.base.grouping import nb
from vectorbtpro.base.indexes import ExceptLevel
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.array_ import is_sorted
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.decorators import cached_method
from vectorbtpro.utils.template import CustomTemplate

__all__ = [
    "Grouper",
]

GroupByT = tp.Union[None, bool, tp.Index]
GrouperT = tp.TypeVar("GrouperT", bound="Grouper")


class Grouper(Configured):
    """Class that exposes methods to group index.

    `group_by` can be:

    * boolean (False for no grouping, True for one group),
    * integer (level by position),
    * string (level by name),
    * sequence of integers or strings that is shorter than `index` (multiple levels),
    * any other sequence that has the same length as `index` (group per index).

    Set `allow_enable` to False to prohibit grouping if `Grouper.group_by` is None.
    Set `allow_disable` to False to prohibit disabling of grouping if `Grouper.group_by` is not None.
    Set `allow_modify` to False to prohibit modifying groups (you can still change their labels).

    All properties are read-only to enable caching."""

    @classmethod
    def group_by_to_index(
        cls,
        index: tp.Index,
        group_by: tp.GroupByLike,
        def_lvl_name: tp.Hashable = "group",
    ) -> GroupByT:
        """Convert mapper `group_by` to `pd.Index`.

        None and False are returned unchanged. A `CustomTemplate` is substituted with
        `index` in its context first. True becomes a constant index with a single label.
        Level specifications (`ExceptLevel`, integer, string, or a short sequence when
        `index` is a `pd.MultiIndex`) are resolved with `indexes.select_levels`.
        Anything else is wrapped into an index named `def_lvl_name`.

        Raises `ValueError` if the resulting mapper and `index` differ in length.

        !!! note
            Index and mapper must have the same length."""
        if group_by is None or group_by is False:
            return group_by
        if isinstance(group_by, CustomTemplate):
            group_by = group_by.substitute(context=dict(index=index), strict=True, eval_id="group_by")
        if group_by is True:
            group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
        elif isinstance(index, pd.MultiIndex) or isinstance(group_by, (ExceptLevel, int, str)):
            if isinstance(group_by, ExceptLevel):
                except_levels = group_by.value
                if isinstance(except_levels, (int, str)):
                    except_levels = [except_levels]
                # Keep every level that is excluded neither by position nor by name
                new_group_by = []
                for i, name in enumerate(index.names):
                    if i not in except_levels and name not in except_levels:
                        new_group_by.append(name)
                if len(new_group_by) == 0:
                    # All levels are excluded: everything falls into one group
                    group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
                else:
                    if len(new_group_by) == 1:
                        new_group_by = new_group_by[0]
                    group_by = indexes.select_levels(index, new_group_by)
            elif isinstance(group_by, (int, str)):
                group_by = indexes.select_levels(index, group_by)
            elif (
                isinstance(group_by, (tuple, list))
                and not isinstance(group_by[0], pd.Index)
                and len(group_by) <= len(index.names)
            ):
                try:
                    group_by = indexes.select_levels(index, group_by)
                except (IndexError, KeyError):
                    # Not a valid level selection: treat the sequence as group labels below
                    pass
        if not isinstance(group_by, pd.Index):
            if isinstance(group_by[0], pd.Index):
                group_by = pd.MultiIndex.from_arrays(group_by)
            else:
                group_by = pd.Index(group_by, name=def_lvl_name)
        if len(group_by) != len(index):
            raise ValueError("group_by and index must have the same length")
        return group_by

    @classmethod
    def group_by_to_groups_and_index(
        cls,
        index: tp.Index,
        group_by: tp.GroupByLike,
        def_lvl_name: tp.Hashable = "group",
    ) -> tp.Tuple[tp.Array1d, tp.Index]:
        """Return array of group indices pointing to the original index, and grouped index.

        Without grouping (None or False), each element forms its own group and
        `index` itself is returned as the grouped index."""
        if group_by is None or group_by is False:
            return np.arange(len(index)), index

        group_by = cls.group_by_to_index(index, group_by, def_lvl_name)
        codes, uniques = pd.factorize(group_by)
        if not isinstance(uniques, pd.Index):
            new_index = pd.Index(uniques)
        else:
            new_index = uniques
        # Carry the name(s) of the mapper over to the grouped index
        if isinstance(group_by, pd.MultiIndex):
            new_index.names = group_by.names
        elif isinstance(group_by, (pd.Index, pd.Series)):
            new_index.name = group_by.name
        return codes, new_index

    @classmethod
    def iter_group_lens(cls, group_lens: tp.GroupLens) -> tp.Iterator[tp.GroupIdxs]:
        """Iterate over indices of each group in group lengths.

        Yields one `np.arange` per group covering the group's contiguous range."""
        group_end_idxs = np.cumsum(group_lens)
        group_start_idxs = group_end_idxs - group_lens
        for group in range(len(group_lens)):
            from_col = group_start_idxs[group]
            to_col = group_end_idxs[group]
            yield np.arange(from_col, to_col)

    @classmethod
    def iter_group_map(cls, group_map: tp.GroupMap) -> tp.Iterator[tp.GroupIdxs]:
        """Iterate over indices of each group in a group map.

        `group_map` is a tuple of indices segmented by group and group lengths;
        each yielded item is the slice of indices that belongs to one group."""
        group_idxs, group_lens = group_map
        group_start = 0
        group_end = 0
        for group in range(len(group_lens)):
            group_len = group_lens[group]
            group_end += group_len
            yield group_idxs[group_start:group_end]
            group_start += group_len

    @classmethod
    def from_pd_group_by(
        cls: tp.Type[GrouperT],
        pd_group_by: tp.PandasGroupByLike,
        **kwargs,
    ) -> GrouperT:
        """Build a `Grouper` instance from a pandas `GroupBy` object.

        Indices are stored under `index` and group labels under `group_by`.

        Raises `TypeError` if `pd_group_by` is neither a `GroupBy` nor a `Resampler`.
        Keyword arguments are passed to the class constructor."""
        from vectorbtpro.base.merging import concat_arrays

        if not isinstance(pd_group_by, (PandasGroupBy, PandasResampler)):
            raise TypeError("pd_group_by must be an instance of GroupBy or Resampler")
        indices = list(pd_group_by.indices.values())
        group_lens = np.asarray(list(map(len, indices)))
        # Mark the first element of every group but the first with 1,
        # then accumulate to obtain the group number of each element
        groups = np.full(int(np.sum(group_lens)), 0, dtype=int_)
        group_start_idxs = np.cumsum(group_lens)[1:] - group_lens[1:]
        groups[group_start_idxs] = 1
        groups = np.cumsum(groups)
        index = pd.Index(concat_arrays(indices))
        group_by = pd.Index(list(pd_group_by.indices.keys()), name="group")[groups]
        return cls(
            index=index,
            group_by=group_by,
            **kwargs,
        )

    def __init__(
        self,
        index: tp.Index,
        group_by: tp.GroupByLike = None,
        def_lvl_name: tp.Hashable = "group",
        allow_enable: bool = True,
        allow_disable: bool = True,
        allow_modify: bool = True,
        **kwargs,
    ) -> None:
        if not isinstance(index, pd.Index):
            index = pd.Index(index)
        # Both None and False mean "no grouping" and are stored as None
        if group_by is None or group_by is False:
            group_by = None
        else:
            group_by = self.group_by_to_index(index, group_by, def_lvl_name=def_lvl_name)

        self._index = index
        self._group_by = group_by
        self._def_lvl_name = def_lvl_name
        self._allow_enable = allow_enable
        self._allow_disable = allow_disable
        self._allow_modify = allow_modify

        Configured.__init__(
            self,
            index=index,
            group_by=group_by,
            def_lvl_name=def_lvl_name,
            allow_enable=allow_enable,
            allow_disable=allow_disable,
            allow_modify=allow_modify,
            **kwargs,
        )

    @property
    def index(self) -> tp.Index:
        """Original index."""
        return self._index

    @property
    def group_by(self) -> GroupByT:
        """Mapper for grouping."""
        return self._group_by

    @property
    def def_lvl_name(self) -> tp.Hashable:
        """Default level name."""
        return self._def_lvl_name

    @property
    def allow_enable(self) -> bool:
        """Whether to allow enabling grouping."""
        return self._allow_enable

    @property
    def allow_disable(self) -> bool:
        """Whether to allow disabling grouping."""
        return self._allow_disable

    @property
    def allow_modify(self) -> bool:
        """Whether to allow changing groups."""
        return self._allow_modify

    def is_grouped(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether the index is grouped.

        False disables grouping, None falls back to `Grouper.group_by`."""
        if group_by is False:
            return False
        if group_by is None:
            group_by = self.group_by
        return group_by is not None

    def is_grouping_enabled(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether grouping has been enabled.

        That is, `Grouper.group_by` is None but `group_by` introduces grouping."""
        return self.group_by is None and self.is_grouped(group_by=group_by)

    def is_grouping_disabled(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether grouping has been disabled.

        That is, `Grouper.group_by` is not None but `group_by` removes grouping."""
        return self.group_by is not None and not self.is_grouped(group_by=group_by)

    @cached_method(whitelist=True)
    def is_grouping_modified(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether grouping has been modified.

        Doesn't care if grouping labels have been changed."""
        if group_by is None or (group_by is False and self.group_by is None):
            return False
        group_by = self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
        if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
            if not pd.Index.equals(group_by, self.group_by):
                # Labels differ: compare the group codes to see whether membership differs too
                groups1 = self.group_by_to_groups_and_index(
                    self.index,
                    group_by,
                    def_lvl_name=self.def_lvl_name,
                )[0]
                groups2 = self.group_by_to_groups_and_index(
                    self.index,
                    self.group_by,
                    def_lvl_name=self.def_lvl_name,
                )[0]
                if not np.array_equal(groups1, groups2):
                    return True
            return False
        return True

    @cached_method(whitelist=True)
    def is_grouping_changed(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether grouping has been changed in any way.

        Unlike `Grouper.is_grouping_modified`, a change of labels counts as a change."""
        if group_by is None or (group_by is False and self.group_by is None):
            return False
        if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
            if pd.Index.equals(group_by, self.group_by):
                return False
        return True

    def is_group_count_changed(self, group_by: tp.GroupByLike = None) -> bool:
        """Check whether the number of groups has changed.

        When both `group_by` and `Grouper.group_by` are indexes, the numbers of their
        distinct labels are compared. Any other effective change of `group_by`
        is reported as a change."""
        if group_by is None or (group_by is False and self.group_by is None):
            return False
        if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
            # Both mappers hold one label per element of the index, so comparing their
            # lengths can never detect a change. Count distinct labels instead, using
            # the same routine that builds the groups in `group_by_to_groups_and_index`
            new_count = len(pd.factorize(group_by)[1])
            old_count = len(pd.factorize(self.group_by)[1])
            return new_count != old_count
        return True

    def check_group_by(
        self,
        group_by: tp.GroupByLike = None,
        allow_enable: tp.Optional[bool] = None,
        allow_disable: tp.Optional[bool] = None,
        allow_modify: tp.Optional[bool] = None,
    ) -> None:
        """Check passed `group_by` object against restrictions.

        Restrictions that are None default to the respective instance attributes.
        Raises `ValueError` if the requested operation is not allowed."""
        if allow_enable is None:
            allow_enable = self.allow_enable
        if allow_disable is None:
            allow_disable = self.allow_disable
        if allow_modify is None:
            allow_modify = self.allow_modify

        if self.is_grouping_enabled(group_by=group_by):
            if not allow_enable:
                raise ValueError("Enabling grouping is not allowed")
        elif self.is_grouping_disabled(group_by=group_by):
            if not allow_disable:
                raise ValueError("Disabling grouping is not allowed")
        elif self.is_grouping_modified(group_by=group_by):
            if not allow_modify:
                raise ValueError("Modifying groups is not allowed")

    def resolve_group_by(self, group_by: tp.GroupByLike = None, **kwargs) -> GroupByT:
        """Resolve `group_by` from either object variable or keyword argument.

        Keyword arguments are passed to `Grouper.check_group_by`."""
        if group_by is None:
            group_by = self.group_by
        if group_by is False and self.group_by is None:
            group_by = None
        self.check_group_by(group_by=group_by, **kwargs)
        return self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)

    @cached_method(whitelist=True)
    def get_groups_and_index(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.Tuple[tp.Array1d, tp.Index]:
        """See `Grouper.group_by_to_groups_and_index`.

        `group_by` is first resolved with `Grouper.resolve_group_by`."""
        group_by = self.resolve_group_by(group_by=group_by, **kwargs)
        return self.group_by_to_groups_and_index(self.index, group_by, def_lvl_name=self.def_lvl_name)

    def get_groups(self, **kwargs) -> tp.Array1d:
        """Return groups array."""
        return self.get_groups_and_index(**kwargs)[0]

    def get_index(self, **kwargs) -> tp.Index:
        """Return grouped index."""
        return self.get_groups_and_index(**kwargs)[1]

    get_grouped_index = get_index

    @property
    def grouped_index(self) -> tp.Index:
        """Grouped index."""
        return self.get_grouped_index()

    def get_stretched_index(self, **kwargs) -> tp.Index:
        """Return stretched index.

        The grouped index is indexed by the groups array, which yields
        the group label of each element of the original index."""
        groups, index = self.get_groups_and_index(**kwargs)
        return index[groups]

    def get_group_count(self, **kwargs) -> int:
        """Get number of groups."""
        return len(self.get_index(**kwargs))

    @cached_method(whitelist=True)
    def is_sorted(self, group_by: tp.GroupByLike = None, **kwargs) -> bool:
        """Return whether groups are monolithic, sorted."""
        group_by = self.resolve_group_by(group_by=group_by, **kwargs)
        groups = self.get_groups(group_by=group_by)
        return is_sorted(groups)

    @cached_method(whitelist=True)
    def get_group_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupLens:
        """See `vectorbtpro.base.grouping.nb.get_group_lens_nb`.

        Without grouping, each element forms a group of length one.
        Raises `ValueError` if groups are not monolithic and sorted."""
        group_by = self.resolve_group_by(group_by=group_by, **kwargs)
        if group_by is None or group_by is False:  # no grouping
            return np.full(len(self.index), 1)
        if not self.is_sorted(group_by=group_by):
            raise ValueError("group_by must form monolithic, sorted groups")
        groups = self.get_groups(group_by=group_by)
        func = jit_reg.resolve_option(nb.get_group_lens_nb, jitted)
        return func(groups)

    def get_group_start_idxs(self, **kwargs) -> tp.Array1d:
        """Get first index of each group as an array."""
        group_lens = self.get_group_lens(**kwargs)
        return np.cumsum(group_lens) - group_lens

    def get_group_end_idxs(self, **kwargs) -> tp.Array1d:
        """Get end index of each group as an array.

        Computed as the cumulative sum of group lengths."""
        group_lens = self.get_group_lens(**kwargs)
        return np.cumsum(group_lens)

    @cached_method(whitelist=True)
    def get_group_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupMap:
        """See `vectorbtpro.base.grouping.nb.get_group_map_nb`.

        Without grouping, each element forms a group of length one."""
        group_by = self.resolve_group_by(group_by=group_by, **kwargs)
        if group_by is None or group_by is False:  # no grouping
            return np.arange(len(self.index)), np.full(len(self.index), 1)
        groups, new_index = self.get_groups_and_index(group_by=group_by)
        func = jit_reg.resolve_option(nb.get_group_map_nb, jitted)
        return func(groups, len(new_index))

    def iter_group_idxs(self, **kwargs) -> tp.Iterator[tp.GroupIdxs]:
        """Iterate over indices of each group.

        Keyword arguments are passed to `Grouper.get_group_map`."""
        group_map = self.get_group_map(**kwargs)
        return self.iter_group_map(group_map)

    def iter_groups(
        self,
        key_as_index: bool = False,
        **kwargs,
    ) -> tp.Iterator[tp.Tuple[tp.Union[tp.Hashable, pd.Index], tp.GroupIdxs]]:
        """Iterate over groups and their indices.

        If `key_as_index` is True, each key is yielded as a one-element index
        instead of a label."""
        index = self.get_index(**kwargs)
        for group, group_idxs in enumerate(self.iter_group_idxs(**kwargs)):
            if key_as_index:
                yield index[[group]], group_idxs
            else:
                yield index[group], group_idxs

    def select_groups(self, group_idxs: tp.Array1d, jitted: tp.JittedOption = None) -> tp.Tuple[tp.Array1d, tp.Array1d]:
        """Select groups.

        Returns indices and new group array. Automatically decides whether to use group lengths or group map."""
        from vectorbtpro.base.reshaping import to_1d_array

        if self.is_sorted():
            func = jit_reg.resolve_option(nb.group_lens_select_nb, jitted)
            new_group_idxs, new_groups = func(self.get_group_lens(), to_1d_array(group_idxs))  # faster
        else:
            func = jit_reg.resolve_option(nb.group_map_select_nb, jitted)
            new_group_idxs, new_groups = func(self.get_group_map(), to_1d_array(group_idxs))  # more flexible
        return new_group_idxs, new_groups


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for grouping."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted

__all__ = []

GroupByT = tp.Union[None, bool, tp.Index]


@register_jitted(cache=True)
def get_group_lens_nb(groups: tp.Array1d) -> tp.GroupLens:
    """Return the length of each group.

    Walks through `groups` once and counts the elements of every run of equal values.
    Raises `ValueError` as soon as a value smaller than the previous one is met.

    !!! note
        Columns must form monolithic, sorted groups. For unsorted groups, use `get_group_map_nb`."""
    n_values = groups.shape[0]
    lens_out = np.empty(n_values, dtype=int_)
    n_found = 0
    prev_group = -1
    run_len = 0
    for i in range(n_values):
        group = groups[i]
        if group < prev_group:
            raise ValueError("Columns must form monolithic, sorted groups")
        if group != prev_group:
            if prev_group != -1:
                # A new group starts here: store the length of the finished one
                lens_out[n_found] = run_len
                n_found += 1
                run_len = 0
            prev_group = group
        run_len += 1
    if n_values > 0:
        # The group that is still open after the last element
        lens_out[n_found] = run_len
        n_found += 1
    return lens_out[:n_found]


@register_jitted(cache=True)
def get_group_map_nb(groups: tp.Array1d, n_groups: int) -> tp.GroupMap:
    """Build the map between groups and indices.

    Returns an array with indices segmented by group and an array with group lengths.

    Works well for unsorted group arrays."""
    n_values = groups.shape[0]

    # First pass: count the members of each group
    lens_out = np.full(n_groups, 0, dtype=int_)
    for i in range(n_values):
        lens_out[groups[i]] += 1

    # Second pass: write each position into the next free slot of its group's segment
    segment_starts = np.cumsum(lens_out) - lens_out
    n_filled = np.full(n_groups, 0, dtype=int_)
    idxs_out = np.empty((n_values,), dtype=int_)
    for i in range(n_values):
        group = groups[i]
        idxs_out[segment_starts[group] + n_filled[group]] = i
        n_filled[group] += 1
    return idxs_out, lens_out


@register_jitted(cache=True)
def group_lens_select_nb(group_lens: tp.GroupLens, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Perform indexing on a sorted array using group lengths.

    Returns indices of elements corresponding to groups in `new_groups` and a new group array."""
    end_idxs = np.cumsum(group_lens)
    start_idxs = end_idxs - group_lens
    n_selected = np.sum(group_lens[new_groups])
    indices_out = np.empty(n_selected, dtype=int_)
    group_arr_out = np.empty(n_selected, dtype=int_)
    k = 0
    for new_group in range(new_groups.shape[0]):
        old_group = new_groups[new_group]
        # Empty groups contribute nothing since the range below is empty
        for idx in range(start_idxs[old_group], end_idxs[old_group]):
            indices_out[k] = idx
            group_arr_out[k] = new_group
            k += 1
    return indices_out, group_arr_out


@register_jitted(cache=True)
def group_map_select_nb(group_map: tp.GroupMap, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Perform indexing using group map.

    Returns indices of elements corresponding to groups in `new_groups` and a new group array."""
    group_idxs, group_lens = group_map
    segment_starts = np.cumsum(group_lens) - group_lens
    n_selected = np.sum(group_lens[new_groups])
    indices_out = np.empty(n_selected, dtype=int_)
    group_arr_out = np.empty(n_selected, dtype=int_)
    k = 0
    for new_group in range(len(new_groups)):
        old_group = new_groups[new_group]
        first = segment_starts[old_group]
        # Copy the segment of the selected group element by element
        for offset in range(group_lens[old_group]):
            indices_out[k] = group_idxs[first + offset]
            group_arr_out[k] = new_group
            k += 1
    return indices_out, group_arr_out


@register_jitted(cache=True)
def group_by_evenly_nb(n: int, n_splits: int) -> tp.Array1d:
    """Get `group_by` from evenly splitting a space of values."""
    out = np.empty(n, dtype=int_)
    for i in range(n):
        scaled = i * n_splits
        out[i] = scaled // n + n_splits // (2 * n)
    return out


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Modules with classes and utilities for resampling.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.base.resampling.base import * from vectorbtpro.base.resampling.nb import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Base classes and functions for resampling.""" import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.base.indexes import repeat_index from vectorbtpro.base.resampling import nb from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.config import Configured from vectorbtpro.utils.decorators import cached_property, hybrid_method from vectorbtpro.utils.warnings_ import warn __all__ = [ "Resampler", ] ResamplerT = tp.TypeVar("ResamplerT", bound="Resampler") class Resampler(Configured): """Class that exposes methods to resample index. Args: source_index (index_like): Index being resampled. target_index (index_like): Index resulted from resampling. source_freq (frequency_like or bool): Frequency or date offset of the source index. 
Set to False to force-set the frequency to None. target_freq (frequency_like or bool): Frequency or date offset of the target index. Set to False to force-set the frequency to None. silence_warnings (bool): Whether to silence all warnings.""" def __init__( self, source_index: tp.IndexLike, target_index: tp.IndexLike, source_freq: tp.Union[None, bool, tp.FrequencyLike] = None, target_freq: tp.Union[None, bool, tp.FrequencyLike] = None, silence_warnings: tp.Optional[bool] = None, **kwargs, ) -> None: source_index = dt.prepare_dt_index(source_index) target_index = dt.prepare_dt_index(target_index) infer_source_freq = True if isinstance(source_freq, bool): if not source_freq: infer_source_freq = False source_freq = None infer_target_freq = True if isinstance(target_freq, bool): if not target_freq: infer_target_freq = False target_freq = None if infer_source_freq: source_freq = dt.infer_index_freq(source_index, freq=source_freq) if infer_target_freq: target_freq = dt.infer_index_freq(target_index, freq=target_freq) self._source_index = source_index self._target_index = target_index self._source_freq = source_freq self._target_freq = target_freq self._silence_warnings = silence_warnings Configured.__init__( self, source_index=source_index, target_index=target_index, source_freq=source_freq, target_freq=target_freq, silence_warnings=silence_warnings, **kwargs, ) @classmethod def from_pd_resampler( cls: tp.Type[ResamplerT], pd_resampler: tp.PandasResampler, source_freq: tp.Optional[tp.FrequencyLike] = None, silence_warnings: bool = True, ) -> ResamplerT: """Build `Resampler` from [pandas.core.resample.Resampler](https://pandas.pydata.org/docs/reference/resampling.html). 
""" target_index = pd_resampler.count().index return cls( source_index=pd_resampler.obj.index, target_index=target_index, source_freq=source_freq, target_freq=None, silence_warnings=silence_warnings, ) @classmethod def from_pd_resample( cls: tp.Type[ResamplerT], source_index: tp.IndexLike, *args, source_freq: tp.Optional[tp.FrequencyLike] = None, silence_warnings: bool = True, **kwargs, ) -> ResamplerT: """Build `Resampler` from [pandas.DataFrame.resample](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html). """ pd_resampler = pd.Series(index=source_index, dtype=object).resample(*args, **kwargs) return cls.from_pd_resampler(pd_resampler, source_freq=source_freq, silence_warnings=silence_warnings) @classmethod def from_date_range( cls: tp.Type[ResamplerT], source_index: tp.IndexLike, *args, source_freq: tp.Optional[tp.FrequencyLike] = None, silence_warnings: tp.Optional[bool] = None, **kwargs, ) -> ResamplerT: """Build `Resampler` from `vectorbtpro.utils.datetime_.date_range`.""" target_index = dt.date_range(*args, **kwargs) return cls( source_index=source_index, target_index=target_index, source_freq=source_freq, target_freq=None, silence_warnings=silence_warnings, ) @property def source_index(self) -> tp.Index: """Index being resampled.""" return self._source_index @property def target_index(self) -> tp.Index: """Index resulted from resampling.""" return self._target_index @property def source_freq(self) -> tp.AnyPandasFrequency: """Frequency or date offset of the source index.""" return self._source_freq @property def target_freq(self) -> tp.AnyPandasFrequency: """Frequency or date offset of the target index.""" return self._target_freq @property def silence_warnings(self) -> bool: """Frequency or date offset of the target index.""" from vectorbtpro._settings import settings resampling_cfg = settings["resampling"] silence_warnings = self._silence_warnings if silence_warnings is None: silence_warnings = resampling_cfg["silence_warnings"] 
return silence_warnings def get_np_source_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency: """Frequency or date offset of the source index in NumPy format.""" if silence_warnings is None: silence_warnings = self.silence_warnings warned = False source_freq = self.source_freq if source_freq is not None: if not isinstance(source_freq, (int, float)): try: source_freq = dt.to_timedelta64(source_freq) except ValueError as e: if not silence_warnings: warn(f"Cannot convert {source_freq} to np.timedelta64. Setting to None.") warned = True source_freq = None if source_freq is None: if not warned and not silence_warnings: warn("Using right bound of source index without frequency. Set source frequency.") return source_freq def get_np_target_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency: """Frequency or date offset of the target index in NumPy format.""" if silence_warnings is None: silence_warnings = self.silence_warnings warned = False target_freq = self.target_freq if target_freq is not None: if not isinstance(target_freq, (int, float)): try: target_freq = dt.to_timedelta64(target_freq) except ValueError as e: if not silence_warnings: warn(f"Cannot convert {target_freq} to np.timedelta64. Setting to None.") warned = True target_freq = None if target_freq is None: if not warned and not silence_warnings: warn("Using right bound of target index without frequency. Set target frequency.") return target_freq @classmethod def get_lbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index: """Get the left bound of a datetime index. 
If `freq` is None, calculates the leftmost bound.""" index = dt.prepare_dt_index(index) checks.assert_instance_of(index, pd.DatetimeIndex) if freq is not None: return index.shift(-1, freq=freq) + pd.Timedelta(1, "ns") min_ts = pd.DatetimeIndex([pd.Timestamp.min.tz_localize(index.tz)]) return (index[:-1] + pd.Timedelta(1, "ns")).append(min_ts) @classmethod def get_rbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index: """Get the right bound of a datetime index. If `freq` is None, calculates the rightmost bound.""" index = dt.prepare_dt_index(index) checks.assert_instance_of(index, pd.DatetimeIndex) if freq is not None: return index.shift(1, freq=freq) - pd.Timedelta(1, "ns") max_ts = pd.DatetimeIndex([pd.Timestamp.max.tz_localize(index.tz)]) return (index[1:] - pd.Timedelta(1, "ns")).append(max_ts) @cached_property def source_lbound_index(self) -> tp.Index: """Get the left bound of the source datetime index.""" return self.get_lbound_index(self.source_index, freq=self.source_freq) @cached_property def source_rbound_index(self) -> tp.Index: """Get the right bound of the source datetime index.""" return self.get_rbound_index(self.source_index, freq=self.source_freq) @cached_property def target_lbound_index(self) -> tp.Index: """Get the left bound of the target datetime index.""" return self.get_lbound_index(self.target_index, freq=self.target_freq) @cached_property def target_rbound_index(self) -> tp.Index: """Get the right bound of the target datetime index.""" return self.get_rbound_index(self.target_index, freq=self.target_freq) def map_to_target_index( self, before: bool = False, raise_missing: bool = True, return_index: bool = True, jitted: tp.JittedOption = None, silence_warnings: tp.Optional[bool] = None, ) -> tp.Union[tp.Array1d, tp.Index]: """See `vectorbtpro.base.resampling.nb.map_to_target_index_nb`.""" target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) func = 
jit_reg.resolve_option(nb.map_to_target_index_nb, jitted) mapped_arr = func( self.source_index.values, self.target_index.values, target_freq=target_freq, before=before, raise_missing=raise_missing, ) if return_index: nan_mask = mapped_arr == -1 if nan_mask.any(): mapped_index = self.source_index.to_series().copy() mapped_index[nan_mask] = np.nan mapped_index[~nan_mask] = self.target_index[mapped_arr] mapped_index = pd.Index(mapped_index) else: mapped_index = self.target_index[mapped_arr] return mapped_index return mapped_arr def index_difference( self, reverse: bool = False, return_index: bool = True, jitted: tp.JittedOption = None, ) -> tp.Union[tp.Array1d, tp.Index]: """See `vectorbtpro.base.resampling.nb.index_difference_nb`.""" func = jit_reg.resolve_option(nb.index_difference_nb, jitted) if reverse: mapped_arr = func(self.target_index.values, self.source_index.values) else: mapped_arr = func(self.source_index.values, self.target_index.values) if return_index: return self.target_index[mapped_arr] return mapped_arr def map_index_to_source_ranges( self, before: bool = False, jitted: tp.JittedOption = None, silence_warnings: tp.Optional[bool] = None, ) -> tp.Tuple[tp.Array1d, tp.Array1d]: """See `vectorbtpro.base.resampling.nb.map_index_to_source_ranges_nb`. If `Resampler.target_freq` is a date offset, sets is to None and gives a warning. 
Raises another warning is `target_freq` is None.""" target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) func = jit_reg.resolve_option(nb.map_index_to_source_ranges_nb, jitted) return func( self.source_index.values, self.target_index.values, target_freq=target_freq, before=before, ) @hybrid_method def map_bounds_to_source_ranges( cls_or_self, source_index: tp.Optional[tp.IndexLike] = None, target_lbound_index: tp.Optional[tp.IndexLike] = None, target_rbound_index: tp.Optional[tp.IndexLike] = None, closed_lbound: bool = True, closed_rbound: bool = False, skip_not_found: bool = False, jitted: tp.JittedOption = None, ) -> tp.Tuple[tp.Array1d, tp.Array1d]: """See `vectorbtpro.base.resampling.nb.map_bounds_to_source_ranges_nb`. Either `target_lbound_index` or `target_rbound_index` must be set. Set `target_lbound_index` and `target_rbound_index` to 'pandas' to use `Resampler.get_lbound_index` and `Resampler.get_rbound_index` respectively. Also, both allow providing a single datetime string and will automatically broadcast to the `Resampler.target_index`.""" if not isinstance(cls_or_self, type): if target_lbound_index is None and target_rbound_index is None: raise ValueError("Either target_lbound_index or target_rbound_index must be set") if target_lbound_index is not None: if isinstance(target_lbound_index, str) and target_lbound_index.lower() == "pandas": target_lbound_index = cls_or_self.target_lbound_index else: target_lbound_index = dt.prepare_dt_index(target_lbound_index) target_rbound_index = cls_or_self.target_index if target_rbound_index is not None: target_lbound_index = cls_or_self.target_index if isinstance(target_rbound_index, str) and target_rbound_index.lower() == "pandas": target_rbound_index = cls_or_self.target_rbound_index else: target_rbound_index = dt.prepare_dt_index(target_rbound_index) if len(target_lbound_index) == 1 and len(target_rbound_index) > 1: target_lbound_index = repeat_index(target_lbound_index, 
len(target_rbound_index)) elif len(target_lbound_index) > 1 and len(target_rbound_index) == 1: target_rbound_index = repeat_index(target_rbound_index, len(target_lbound_index)) else: source_index = dt.prepare_dt_index(source_index) target_lbound_index = dt.prepare_dt_index(target_lbound_index) target_rbound_index = dt.prepare_dt_index(target_rbound_index) checks.assert_len_equal(target_rbound_index, target_lbound_index) func = jit_reg.resolve_option(nb.map_bounds_to_source_ranges_nb, jitted) return func( source_index.values, target_lbound_index.values, target_rbound_index.values, closed_lbound=closed_lbound, closed_rbound=closed_rbound, skip_not_found=skip_not_found, ) def resample_source_mask( self, source_mask: tp.ArrayLike, jitted: tp.JittedOption = None, silence_warnings: tp.Optional[bool] = None, ) -> tp.Array1d: """See `vectorbtpro.base.resampling.nb.resample_source_mask_nb`.""" from vectorbtpro.base.reshaping import broadcast_array_to if silence_warnings is None: silence_warnings = self.silence_warnings source_mask = broadcast_array_to(source_mask, len(self.source_index)) source_freq = self.get_np_source_freq(silence_warnings=silence_warnings) target_freq = self.get_np_target_freq(silence_warnings=silence_warnings) func = jit_reg.resolve_option(nb.resample_source_mask_nb, jitted) return func( source_mask, self.source_index.values, self.target_index.values, source_freq, target_freq, ) def last_before_target_index(self, incl_source: bool = True, jitted: tp.JittedOption = None) -> tp.Array1d: """See `vectorbtpro.base.resampling.nb.last_before_target_index_nb`.""" func = jit_reg.resolve_option(nb.last_before_target_index_nb, jitted) return func(self.source_index.values, self.target_index.values, incl_source=incl_source) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for resampling."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils.datetime_nb import d_td

__all__ = []


@register_jitted(cache=True)
def date_range_nb(
    start: np.datetime64,
    end: np.datetime64,
    freq: np.timedelta64 = d_td,
    incl_left: bool = True,
    incl_right: bool = True,
) -> tp.Array1d:
    """Generate a datetime index with nanosecond precision from a date range.

    Inspired by [pandas.date_range](https://pandas.pydata.org/docs/reference/api/pandas.date_range.html).

    `incl_left` and `incl_right` control whether `start` and `end` are kept when they
    fall exactly on a generated value. If `end` comes before `start`, returns an empty array."""
    values_len = int(np.floor((end - start) / freq)) + 1
    if values_len < 0:
        # `end` lies more than one step before `start`: nothing to generate
        values_len = 0
    values = np.empty(values_len, dtype="datetime64[ns]")
    for i in range(values_len):
        values[i] = start + i * freq
    if start == end:
        # A single point is dropped only if both sides are open
        if not incl_left and not incl_right:
            values = values[1:-1]
    else:
        if not incl_left or not incl_right:
            if not incl_left and len(values) and values[0] == start:
                values = values[1:]
            if not incl_right and len(values) and values[-1] == end:
                values = values[:-1]
    return values


@register_jitted(cache=True)
def map_to_target_index_nb(
    source_index: tp.Array1d,
    target_index: tp.Array1d,
    target_freq: tp.Optional[tp.Scalar] = None,
    before: bool = False,
    raise_missing: bool = True,
) -> tp.Array1d:
    """Get the index of each from `source_index` in `target_index`.

    If `before` is True, applied on elements that come before and including that index.
    Otherwise, applied on elements that come after and including that index.

    If `target_freq` is not None, each target bar spans exactly one `target_freq`.
    Otherwise, it extends to the neighboring target index.

    If `raise_missing` is True, will throw an error if an index cannot be mapped.
    Otherwise, the element for that index becomes -1.

    !!! note
        Source index must be increasing (repeating values are allowed).
        Target index must be strictly increasing."""
    out = np.empty(len(source_index), dtype=int_)
    from_j = 0
    for i in range(len(source_index)):
        if i > 0 and source_index[i] < source_index[i - 1]:
            raise ValueError("Source index must be increasing")
        if i > 0 and source_index[i] == source_index[i - 1]:
            # Same value as before maps to the same target, no need to scan again
            out[i] = out[i - 1]
            continue

        found = False
        for j in range(from_j, len(target_index)):
            if j > 0 and target_index[j] <= target_index[j - 1]:
                raise ValueError("Target index must be strictly increasing")
            if target_freq is None:
                if before and source_index[i] <= target_index[j]:
                    if j == 0 or target_index[j - 1] < source_index[i]:
                        out[i] = from_j = j
                        found = True
                        break
                if not before and target_index[j] <= source_index[i]:
                    if j == len(target_index) - 1 or source_index[i] < target_index[j + 1]:
                        out[i] = from_j = j
                        found = True
                        break
            else:
                if before and target_index[j] - target_freq < source_index[i] <= target_index[j]:
                    out[i] = from_j = j
                    found = True
                    break
                if not before and target_index[j] <= source_index[i] < target_index[j] + target_freq:
                    out[i] = from_j = j
                    found = True
                    break
        if not found:
            if raise_missing:
                raise ValueError("Resampling failed: cannot map some source indices")
            out[i] = -1
    return out


@register_jitted(cache=True)
def index_difference_nb(
    source_index: tp.Array1d,
    target_index: tp.Array1d,
) -> tp.Array1d:
    """Get the positions of the elements in `source_index` not present in `target_index`.

    !!! note
        Both index arrays must be strictly increasing."""
    out = np.empty(len(source_index), dtype=int_)
    from_j = 0
    k = 0
    for i in range(len(source_index)):
        if i > 0 and source_index[i] <= source_index[i - 1]:
            raise ValueError("Array index must be strictly increasing")
        found = False
        for j in range(from_j, len(target_index)):
            if j > 0 and target_index[j] <= target_index[j - 1]:
                raise ValueError("Target index must be strictly increasing")
            if source_index[i] < target_index[j]:
                break
            if source_index[i] == target_index[j]:
                from_j = j
                found = True
                break
            # Target elements smaller than the current source element can be skipped next time
            from_j = j
        if not found:
            out[k] = i
            k += 1
    return out[:k]


@register_jitted(cache=True)
def map_index_to_source_ranges_nb(
    source_index: tp.Array1d,
    target_index: tp.Array1d,
    target_freq: tp.Optional[tp.Scalar] = None,
    before: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Get the source bounds that correspond to each target index.

    If `target_freq` is not None, the right bound is limited by the frequency in `target_freq`.
    Otherwise, the right bound is the next index in `target_index`.

    If `before` is True, each range ends at (and includes) the target index instead of starting there.

    Returns a tuple of two arrays where the first array is the absolute start index (including)
    and the second array is the absolute end index (excluding).

    If an element cannot be mapped, the start and end of the range becomes -1.

    !!! note
        Both index arrays must be increasing. Repeating values are allowed."""
    range_starts_out = np.empty(len(target_index), dtype=int_)
    range_ends_out = np.empty(len(target_index), dtype=int_)

    to_j = 0
    for i in range(len(target_index)):
        if i > 0 and target_index[i] < target_index[i - 1]:
            raise ValueError("Target index must be increasing")

        from_j = -1
        # Ranges don't overlap, so the search resumes where the previous range has ended
        for j in range(to_j, len(source_index)):
            if j > 0 and source_index[j] < source_index[j - 1]:
                raise ValueError("Array index must be increasing")
            found = False
            if target_freq is None:
                if before:
                    if i == 0 and source_index[j] <= target_index[i]:
                        found = True
                    elif i > 0 and target_index[i - 1] < source_index[j] <= target_index[i]:
                        found = True
                    elif source_index[j] > target_index[i]:
                        break
                else:
                    if i == len(target_index) - 1 and target_index[i] <= source_index[j]:
                        found = True
                    elif i < len(target_index) - 1 and target_index[i] <= source_index[j] < target_index[i + 1]:
                        found = True
                    elif i < len(target_index) - 1 and source_index[j] >= target_index[i + 1]:
                        break
            else:
                if before:
                    if target_index[i] - target_freq < source_index[j] <= target_index[i]:
                        found = True
                    elif source_index[j] > target_index[i]:
                        break
                else:
                    if target_index[i] <= source_index[j] < target_index[i] + target_freq:
                        found = True
                    elif source_index[j] >= target_index[i] + target_freq:
                        break
            if found:
                if from_j == -1:
                    from_j = j
                to_j = j + 1

        if from_j == -1:
            range_starts_out[i] = -1
            range_ends_out[i] = -1
        else:
            range_starts_out[i] = from_j
            range_ends_out[i] = to_j

    return range_starts_out, range_ends_out


@register_jitted(cache=True)
def map_bounds_to_source_ranges_nb(
    source_index: tp.Array1d,
    target_lbound_index: tp.Array1d,
    target_rbound_index: tp.Array1d,
    closed_lbound: bool = True,
    closed_rbound: bool = False,
    skip_not_found: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Get the source bounds that correspond to the target bounds.

    `closed_lbound` and `closed_rbound` control whether the left and right target bound
    are included in the range respectively.

    Returns a tuple of two arrays where the first array is the absolute start index (including)
    and the second array is the absolute end index (excluding).

    If an element cannot be mapped, the start and end of the range becomes -1.
    If `skip_not_found` is True, such elements are left out of the output instead.

    !!! note
        Both index arrays must be increasing. Repeating values are allowed."""
    range_starts_out = np.empty(len(target_lbound_index), dtype=int_)
    range_ends_out = np.empty(len(target_lbound_index), dtype=int_)
    k = 0

    to_j = 0
    for i in range(len(target_lbound_index)):
        if i > 0 and target_lbound_index[i] < target_lbound_index[i - 1]:
            raise ValueError("Target left-bound index must be increasing")
        if i > 0 and target_rbound_index[i] < target_rbound_index[i - 1]:
            raise ValueError("Target right-bound index must be increasing")

        from_j = -1
        # Every search starts from the first source element rather than from the previous end
        for j in range(len(source_index)):
            if j > 0 and source_index[j] < source_index[j - 1]:
                raise ValueError("Array index must be increasing")
            found = False
            if closed_lbound and closed_rbound:
                if target_lbound_index[i] <= source_index[j] <= target_rbound_index[i]:
                    found = True
                elif source_index[j] > target_rbound_index[i]:
                    break
            elif closed_lbound:
                if target_lbound_index[i] <= source_index[j] < target_rbound_index[i]:
                    found = True
                elif source_index[j] >= target_rbound_index[i]:
                    break
            elif closed_rbound:
                if target_lbound_index[i] < source_index[j] <= target_rbound_index[i]:
                    found = True
                elif source_index[j] > target_rbound_index[i]:
                    break
            else:
                if target_lbound_index[i] < source_index[j] < target_rbound_index[i]:
                    found = True
                elif source_index[j] >= target_rbound_index[i]:
                    break
            if found:
                if from_j == -1:
                    from_j = j
                to_j = j + 1

        if skip_not_found:
            if from_j != -1:
                range_starts_out[k] = from_j
                range_ends_out[k] = to_j
                k += 1
        else:
            if from_j == -1:
                range_starts_out[i] = -1
                range_ends_out[i] = -1
            else:
                range_starts_out[i] = from_j
                range_ends_out[i] = to_j

    if skip_not_found:
        return range_starts_out[:k], range_ends_out[:k]
    return range_starts_out, range_ends_out


@register_jitted(cache=True)
def resample_source_mask_nb(
    source_mask: tp.Array1d,
    source_index: tp.Array1d,
    target_index: tp.Array1d,
    source_freq: tp.Optional[tp.Scalar] = None,
    target_freq: tp.Optional[tp.Scalar] = None,
) -> tp.Array1d:
    """Resample a source mask to the target index.

    Becomes True only if the target bar is fully contained in the source bar. The source bar
    is represented by a non-interrupting sequence of True values in the source mask.

    If a frequency is None, the right bound of a bar is the next index, and the last bar
    has no right bound.

    !!! note
        Both index arrays must be increasing."""
    out = np.full(len(target_index), False, dtype=np.bool_)

    from_j = 0
    for i in range(len(target_index)):
        if i > 0 and target_index[i] < target_index[i - 1]:
            raise ValueError("Target index must be increasing")
        target_lbound = target_index[i]
        if target_freq is None:
            if i + 1 < len(target_index):
                target_rbound = target_index[i + 1]
            else:
                target_rbound = None
        else:
            target_rbound = target_index[i] + target_freq

        found_start = False
        for j in range(from_j, len(source_index)):
            if j > 0 and source_index[j] < source_index[j - 1]:
                raise ValueError("Source index must be increasing")
            source_lbound = source_index[j]
            if source_freq is None:
                if j + 1 < len(source_index):
                    source_rbound = source_index[j + 1]
                else:
                    source_rbound = None
            else:
                source_rbound = source_index[j] + source_freq

            if target_rbound is not None and target_rbound <= source_lbound:
                # Source bar starts after the target bar has ended
                break
            if found_start or (
                target_lbound >= source_lbound and (source_rbound is None or target_lbound < source_rbound)
            ):
                if not found_start:
                    from_j = j
                    found_start = True
                if not source_mask[j]:
                    # Sequence of True values got interrupted
                    break
                if source_rbound is None or (target_rbound is not None and target_rbound <= source_rbound):
                    out[i] = True
                    break

    return out


@register_jitted(cache=True)
def last_before_target_index_nb(
    source_index: tp.Array1d,
    target_index: tp.Array1d,
    incl_source: bool = True,
    incl_target: bool = False,
) -> tp.Array1d:
    """For each source index, find the position of the last source index between the original
    source index and the corresponding target index.

    If `incl_source` is True, the position of the original source index is the fallback result.
    Otherwise, the fallback is -1. If `incl_target` is True, a source index equal to
    the target index counts as well.

    !!! note
        Both index arrays must be increasing, and each target index must be equal to
        or greater than the source index at the same position."""
    out = np.empty(len(source_index), dtype=int_)

    last_j = -1
    for i in range(len(source_index)):
        if i > 0 and source_index[i] < source_index[i - 1]:
            raise ValueError("Source index must be increasing")
        if i > 0 and target_index[i] < target_index[i - 1]:
            raise ValueError("Target index must be increasing")
        if source_index[i] > target_index[i]:
            raise ValueError("Target index must be equal to or greater than source index")
        # Resume from the previous result if there is one
        if last_j == -1:
            from_i = i + 1
        else:
            from_i = last_j
        if incl_source:
            last_j = i
        else:
            last_j = -1
        for j in range(from_i, len(source_index)):
            if source_index[j] < target_index[i]:
                last_j = j
            elif incl_target and source_index[j] == target_index[i]:
                last_j = j
            else:
                break
        out[i] = last_j

    return out


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== """Modules with base classes and utilities for pandas objects, such as broadcasting.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.base.grouping import * from vectorbtpro.base.resampling import * from vectorbtpro.base.accessors import * from vectorbtpro.base.chunking import * from vectorbtpro.base.combining import * from vectorbtpro.base.decorators import * from vectorbtpro.base.flex_indexing import * from vectorbtpro.base.indexes import * from vectorbtpro.base.indexing import * from vectorbtpro.base.merging import * from vectorbtpro.base.preparing import * from vectorbtpro.base.reshaping import * from vectorbtpro.base.wrapping import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# =================================================================================== """Custom Pandas accessors for base operations with Pandas objects.""" import ast import inspect import numpy as np import pandas as pd from pandas.api.types import is_scalar from pandas.core.groupby import GroupBy as PandasGroupBy from pandas.core.resample import Resampler as PandasResampler from vectorbtpro import _typing as tp from vectorbtpro.base import combining, reshaping, indexes from vectorbtpro.base.grouping.base import Grouper from vectorbtpro.base.indexes import IndexApplier from vectorbtpro.base.indexing import ( point_idxr_defaults, range_idxr_defaults, get_index_points, get_index_ranges, ) from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer from vectorbtpro.utils.config import merge_dicts, resolve_dict, Configured from vectorbtpro.utils.decorators import hybrid_property, hybrid_method from vectorbtpro.utils.execution import Task, execute from vectorbtpro.utils.eval_ import evaluate from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods from vectorbtpro.utils.parsing import get_context_vars, get_func_arg_names from vectorbtpro.utils.template import substitute_templates from vectorbtpro.utils.warnings_ import warn if tp.TYPE_CHECKING: from vectorbtpro.data.base import Data as DataT else: DataT = "Data" if tp.TYPE_CHECKING: from vectorbtpro.generic.splitting.base import Splitter as SplitterT else: SplitterT = "Splitter" __all__ = ["BaseIDXAccessor", "BaseAccessor", "BaseSRAccessor", "BaseDFAccessor"] BaseIDXAccessorT = tp.TypeVar("BaseIDXAccessorT", bound="BaseIDXAccessor") class BaseIDXAccessor(Configured, IndexApplier): """Accessor on top of Index. 
Accessible via `pd.Index.vbt` and all child accessors.""" def __init__(self, obj: tp.Index, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs) -> None: checks.assert_instance_of(obj, pd.Index) Configured.__init__(self, obj=obj, freq=freq, **kwargs) self._obj = obj self._freq = freq @property def obj(self) -> tp.Index: """Pandas object.""" return self._obj def get(self) -> tp.Index: """Get `IDXAccessor.obj`.""" return self.obj # ############# Index ############# # def to_ns(self) -> tp.Array1d: """Convert index to an 64-bit integer array. Timestamps will be converted to nanoseconds.""" return dt.to_ns(self.obj) def to_period(self, freq: tp.FrequencyLike, shift: bool = False) -> pd.PeriodIndex: """Convert index to period.""" index = self.obj if isinstance(index, pd.DatetimeIndex): index = index.tz_localize(None).to_period(freq) if shift: index = index.shift() if not isinstance(index, pd.PeriodIndex): raise TypeError(f"Cannot convert index of type {type(index)} to period") return index def to_period_ts(self, *args, **kwargs) -> pd.DatetimeIndex: """Convert index to period and then to timestamp.""" new_index = self.to_period(*args, **kwargs).to_timestamp() if self.obj.tz is not None: new_index = new_index.tz_localize(self.obj.tz) return new_index def to_period_ns(self, *args, **kwargs) -> tp.Array1d: """Convert index to period and then to an 64-bit integer array. 
Timestamps will be converted to nanoseconds.""" return dt.to_ns(self.to_period_ts(*args, **kwargs)) @classmethod def from_values(cls, *args, **kwargs) -> tp.Index: """See `vectorbtpro.base.indexes.index_from_values`.""" return indexes.index_from_values(*args, **kwargs) def repeat(self, *args, **kwargs) -> tp.Index: """See `vectorbtpro.base.indexes.repeat_index`.""" return indexes.repeat_index(self.obj, *args, **kwargs) def tile(self, *args, **kwargs) -> tp.Index: """See `vectorbtpro.base.indexes.tile_index`.""" return indexes.tile_index(self.obj, *args, **kwargs) @hybrid_method def stack( cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], on_top: bool = False, **kwargs, ) -> tp.Index: """See `vectorbtpro.base.indexes.stack_indexes`. Set `on_top` to True to stack the second index on top of this one.""" others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) if isinstance(cls_or_self, type): objs = others else: if on_top: objs = (*others, cls_or_self.obj) else: objs = (cls_or_self.obj, *others) return indexes.stack_indexes(*objs, **kwargs) @hybrid_method def combine( cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], on_top: bool = False, **kwargs, ) -> tp.Index: """See `vectorbtpro.base.indexes.combine_indexes`. 
Set `on_top` to True to stack the second index on top of this one.""" others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) if isinstance(cls_or_self, type): objs = others else: if on_top: objs = (*others, cls_or_self.obj) else: objs = (cls_or_self.obj, *others) return indexes.combine_indexes(*objs, **kwargs) @hybrid_method def concat(cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs) -> tp.Index: """See `vectorbtpro.base.indexes.concat_indexes`.""" others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) if isinstance(cls_or_self, type): objs = others else: objs = (cls_or_self.obj, *others) return indexes.concat_indexes(*objs, **kwargs) def apply_to_index( self: BaseIDXAccessorT, apply_func: tp.Callable, *args, **kwargs, ) -> tp.Index: return self.replace(obj=apply_func(self.obj, *args, **kwargs)).obj def align_to(self, *args, **kwargs) -> tp.IndexSlice: """See `vectorbtpro.base.indexes.align_index_to`.""" return indexes.align_index_to(self.obj, *args, **kwargs) @hybrid_method def align( cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs, ) -> tp.Tuple[tp.IndexSlice, ...]: """See `vectorbtpro.base.indexes.align_indexes`.""" others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) if isinstance(cls_or_self, type): objs = others else: objs = (cls_or_self.obj, *others) return indexes.align_indexes(*objs, **kwargs) def cross_with(self, *args, **kwargs) -> tp.Tuple[tp.IndexSlice, tp.IndexSlice]: """See `vectorbtpro.base.indexes.cross_index_with`.""" return indexes.cross_index_with(self.obj, *args, **kwargs) @hybrid_method def cross( cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs, ) -> tp.Tuple[tp.IndexSlice, ...]: """See `vectorbtpro.base.indexes.cross_indexes`.""" others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others)) if isinstance(cls_or_self, type): objs = others else: objs 
= (cls_or_self.obj, *others) return indexes.cross_indexes(*objs, **kwargs) x = cross def find_first_occurrence(self, *args, **kwargs) -> int: """See `vectorbtpro.base.indexes.find_first_occurrence`.""" return indexes.find_first_occurrence(self.obj, *args, **kwargs) # ############# Frequency ############# # @hybrid_method def get_freq( cls_or_self, index: tp.Optional[tp.Index] = None, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs, ) -> tp.Union[None, float, tp.PandasFrequency]: """Index frequency as `pd.Timedelta` or None if it cannot be converted.""" from vectorbtpro._settings import settings wrapping_cfg = settings["wrapping"] if not isinstance(cls_or_self, type): if index is None: index = cls_or_self.obj if freq is None: freq = cls_or_self._freq else: checks.assert_not_none(index, arg_name="index") if freq is None: freq = wrapping_cfg["freq"] try: return dt.infer_index_freq(index, freq=freq, **kwargs) except Exception as e: return None @property def freq(self) -> tp.Optional[tp.PandasFrequency]: """`BaseIDXAccessor.get_freq` with date offsets and integer frequencies not allowed.""" return self.get_freq(allow_offset=True, allow_numeric=False) @property def ns_freq(self) -> tp.Optional[int]: """Convert frequency to a 64-bit integer. 
Timedelta will be converted to nanoseconds.""" freq = self.get_freq(allow_offset=False, allow_numeric=True) if freq is not None: freq = dt.to_ns(dt.to_timedelta64(freq)) return freq @property def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]: """Index frequency of any type.""" return self.get_freq() @hybrid_method def get_periods(cls_or_self, index: tp.Optional[tp.Index] = None) -> int: """Get the number of periods in the index, without taking into account its datetime-like properties.""" if not isinstance(cls_or_self, type): if index is None: index = cls_or_self.obj else: checks.assert_not_none(index, arg_name="index") return len(index) @property def periods(self) -> int: """`BaseIDXAccessor.get_periods` with default arguments.""" return len(self.obj) @hybrid_method def get_dt_periods( cls_or_self, index: tp.Optional[tp.Index] = None, freq: tp.Optional[tp.PandasFrequency] = None, ) -> float: """Get the number of periods in the index, taking into account its datetime-like properties.""" from vectorbtpro._settings import settings wrapping_cfg = settings["wrapping"] if not isinstance(cls_or_self, type): if index is None: index = cls_or_self.obj else: checks.assert_not_none(index, arg_name="index") if isinstance(index, pd.DatetimeIndex): freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=True, allow_numeric=False) if freq is not None: if not isinstance(freq, pd.Timedelta): freq = dt.to_timedelta(freq, approximate=True) return (index[-1] - index[0]) / freq + 1 if not wrapping_cfg["silence_warnings"]: warn( "Couldn't parse the frequency of index. Pass it as `freq` or " "define it globally under `settings.wrapping`." 
) if checks.is_number(index[0]) and checks.is_number(index[-1]): freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=False, allow_numeric=True) if checks.is_number(freq): return (index[-1] - index[0]) / freq + 1 return index[-1] - index[0] + 1 if not wrapping_cfg["silence_warnings"]: warn("Index is neither datetime-like nor integer") return cls_or_self.get_periods(index=index) @property def dt_periods(self) -> float: """`BaseIDXAccessor.get_dt_periods` with default arguments.""" return self.get_dt_periods() def arr_to_timedelta( self, a: tp.MaybeArray, to_pd: bool = False, silence_warnings: tp.Optional[bool] = None, ) -> tp.Union[pd.Index, tp.MaybeArray]: """Convert array to duration using `BaseIDXAccessor.freq`.""" from vectorbtpro._settings import settings wrapping_cfg = settings["wrapping"] if silence_warnings is None: silence_warnings = wrapping_cfg["silence_warnings"] freq = self.freq if freq is None: if not silence_warnings: warn( "Couldn't parse the frequency of index. Pass it as `freq` or " "define it globally under `settings.wrapping`." ) return a if not isinstance(freq, pd.Timedelta): freq = dt.to_timedelta(freq, approximate=True) if to_pd: out = pd.to_timedelta(a * freq) else: out = a * freq return out # ############# Grouping ############# # def get_grouper(self, by: tp.AnyGroupByLike, groupby_kwargs: tp.KwargsLike = None, **kwargs) -> Grouper: """Get an index grouper of type `vectorbtpro.base.grouping.base.Grouper`. Argument `by` can be a grouper itself, an instance of Pandas `GroupBy`, an instance of Pandas `Resampler`, but also any supported input to any of them such as a frequency or an array of indices. 
Keyword arguments `groupby_kwargs` are passed to the Pandas methods `groupby` and `resample`, while `**kwargs` are passed to initialize `vectorbtpro.base.grouping.base.Grouper`.""" if groupby_kwargs is None: groupby_kwargs = {} if isinstance(by, Grouper): if len(kwargs) > 0: return by.replace(**kwargs) return by if isinstance(by, (PandasGroupBy, PandasResampler)): return Grouper.from_pd_group_by(by, **kwargs) try: return Grouper(index=self.obj, group_by=by, **kwargs) except Exception as e: pass if isinstance(self.obj, pd.DatetimeIndex): try: return Grouper(index=self.obj, group_by=self.to_period(dt.to_freq(by)), **kwargs) except Exception as e: pass try: pd_group_by = pd.Series(index=self.obj, dtype=object).resample(dt.to_freq(by), **groupby_kwargs) return Grouper.from_pd_group_by(pd_group_by, **kwargs) except Exception as e: pass pd_group_by = pd.Series(index=self.obj, dtype=object).groupby(by, axis=0, **groupby_kwargs) return Grouper.from_pd_group_by(pd_group_by, **kwargs) def get_resampler( self, rule: tp.AnyRuleLike, freq: tp.Optional[tp.FrequencyLike] = None, resample_kwargs: tp.KwargsLike = None, return_pd_resampler: bool = False, silence_warnings: tp.Optional[bool] = None, ) -> tp.Union[Resampler, tp.PandasResampler]: """Get an index resampler of type `vectorbtpro.base.resampling.base.Resampler`.""" if checks.is_frequency_like(rule): try: rule = dt.to_freq(rule) is_td = True except Exception as e: is_td = False if is_td: resample_kwargs = merge_dicts( dict(closed="left", label="left"), resample_kwargs, ) rule = pd.Series(index=self.obj, dtype=object).resample(rule, **resolve_dict(resample_kwargs)) if isinstance(rule, PandasResampler): if return_pd_resampler: return rule if silence_warnings is None: silence_warnings = True rule = Resampler.from_pd_resampler(rule, source_freq=self.freq, silence_warnings=silence_warnings) if return_pd_resampler: raise TypeError("Cannot convert Resampler to Pandas Resampler") if checks.is_dt_like(rule) or 
checks.is_iterable(rule): rule = dt.prepare_dt_index(rule) rule = Resampler( source_index=self.obj, target_index=rule, source_freq=self.freq, target_freq=freq, silence_warnings=silence_warnings, ) if isinstance(rule, Resampler): if freq is not None: rule = rule.replace(target_freq=freq) return rule raise ValueError(f"Cannot build Resampler from {rule}") # ############# Points and ranges ############# # def get_points(self, *args, **kwargs) -> tp.Array1d: """See `vectorbtpro.base.indexing.get_index_points`.""" return get_index_points(self.obj, *args, **kwargs) def get_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]: """See `vectorbtpro.base.indexing.get_index_ranges`.""" return get_index_ranges(self.obj, self.any_freq, *args, **kwargs) # ############# Splitting ############# # def split(self, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, **kwargs) -> tp.Any: """Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_take`. !!! note Splits Pandas object, not accessor!""" from vectorbtpro.generic.splitting.base import Splitter if splitter_cls is None: splitter_cls = Splitter return splitter_cls.split_and_take(self.obj, self.obj, *args, **kwargs) def split_apply( self, apply_func: tp.Callable, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, **kwargs, ) -> tp.Any: """Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`. !!! 
def chunk(
    self: BaseIDXAccessorT,
    min_size: tp.Optional[int] = None,
    n_chunks: tp.Union[None, int, str] = None,
    chunk_len: tp.Union[None, int, str] = None,
    chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
    select: bool = False,
    return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[tp.Index, tp.Tuple[ChunkMeta, tp.Index]]]:
    """Chunk the index.

    If `chunk_meta` is None, it is generated by `vectorbtpro.utils.chunking.iter_chunk_meta`
    from the length of the index and the arguments `min_size`, `n_chunks`, and `chunk_len`.

    Each chunk is taken using `ArraySelector` if `select` is True and `ArraySlicer` otherwise.
    Set `return_chunk_meta` to True to yield tuples of the chunk meta and the chunk instead of
    the chunk alone.

    !!! note
        Splits Pandas object, not accessor!"""
    if chunk_meta is None:
        chunk_meta = iter_chunk_meta(
            size=len(self.obj),
            min_size=min_size,
            n_chunks=n_chunks,
            chunk_len=chunk_len,
        )
    # The taker depends only on `select`, so build it once instead of once per chunk
    array_taker = ArraySelector() if select else ArraySlicer()
    for _chunk_meta in chunk_meta:
        chunk = array_taker.take(self.obj, _chunk_meta)
        if return_chunk_meta:
            yield _chunk_meta, chunk
        else:
            yield chunk
note Splits Pandas object, not accessor!""" if isinstance(apply_func, str): apply_func = getattr(type(self), apply_func) if chunk_kwargs is None: chunk_arg_names = set(get_func_arg_names(self.chunk)) chunk_kwargs = {} for k in list(kwargs.keys()): if k in chunk_arg_names: chunk_kwargs[k] = kwargs.pop(k) if execute_kwargs is None: execute_kwargs = {} chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs) tasks = [] keys = [] for _chunk_meta, chunk in chunks: tasks.append(Task(apply_func, chunk, *args, **kwargs)) keys.append(get_chunk_meta_key(_chunk_meta)) keys = pd.Index(keys, name="chunk_indices") return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs) BaseAccessorT = tp.TypeVar("BaseAccessorT", bound="BaseAccessor") @attach_binary_magic_methods(lambda self, other, np_func: self.combine(other, combine_func=np_func)) @attach_unary_magic_methods(lambda self, np_func: self.apply(apply_func=np_func)) class BaseAccessor(Wrapping): """Accessor on top of Series and DataFrames. Accessible via `pd.Series.vbt` and `pd.DataFrame.vbt`, and all child accessors. Series is just a DataFrame with one column, hence to avoid defining methods exclusively for 1-dim data, we will convert any Series to a DataFrame and perform matrix computation on it. Afterwards, by using `BaseAccessor.wrapper`, we will convert the 2-dim output back to a Series. `**kwargs` will be passed to `vectorbtpro.base.wrapping.ArrayWrapper`. !!! note When using magic methods, ensure that `.vbt` is called on the operand on the left if the other operand is an array. Accessors do not utilize caching. Grouping is only supported by the methods that accept the `group_by` argument. 
Usage: * Build a symmetric matrix: ```pycon >>> from vectorbtpro import * >>> # vectorbtpro.base.accessors.BaseAccessor.make_symmetric >>> pd.Series([1, 2, 3]).vbt.make_symmetric() 0 1 2 0 1.0 2.0 3.0 1 2.0 NaN NaN 2 3.0 NaN NaN ``` * Broadcast pandas objects: ```pycon >>> sr = pd.Series([1]) >>> df = pd.DataFrame([1, 2, 3]) >>> vbt.base.reshaping.broadcast_to(sr, df) 0 0 1 1 1 2 1 >>> sr.vbt.broadcast_to(df) 0 0 1 1 1 2 1 ``` * Many methods such as `BaseAccessor.broadcast` are both class and instance methods: ```pycon >>> from vectorbtpro.base.accessors import BaseAccessor >>> # Same as sr.vbt.broadcast(df) >>> new_sr, new_df = BaseAccessor.broadcast(sr, df) >>> new_sr 0 0 1 1 1 2 1 >>> new_df 0 0 1 1 2 2 3 ``` * Instead of explicitly importing `BaseAccessor` or any other accessor, we can use `pd_acc` instead: ```pycon >>> vbt.pd_acc.broadcast(sr, df) >>> new_sr 0 0 1 1 1 2 1 >>> new_df 0 0 1 1 2 2 3 ``` * `BaseAccessor` implements arithmetic (such as `+`), comparison (such as `>`) and logical operators (such as `&`) by forwarding the operation to `BaseAccessor.combine`: ```pycon >>> sr.vbt + df 0 0 2 1 3 2 4 ``` Many interesting use cases can be implemented this way. 
@hybrid_method
def row_stack(
    cls_or_self: tp.MaybeType[BaseAccessorT],
    *objs: tp.MaybeTuple[BaseAccessorT],
    wrapper_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> BaseAccessorT:
    """Stack multiple `BaseAccessor` instances along rows.

    Objects can be passed as separate positional arguments or as a single sequence.
    When called on an instance, the instance is stacked first.

    If a non-None `wrapper` is provided in `kwargs`, it is used as the new wrapper (after
    applying `wrapper_kwargs` via `replace`, if any). Otherwise, the wrappers of all objects
    are stacked using `vectorbtpro.base.wrapping.ArrayWrapper.row_stack`, which receives
    `wrapper_kwargs`.

    Returns a Series accessor if the new wrapper is one-dimensional, otherwise
    a DataFrame accessor.

    Raises:
        TypeError: If any object is not an instance of `BaseAccessor`."""
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        cls = type(cls_or_self)
        objs = (cls_or_self, *objs)
    if len(objs) == 1:
        # A single positional argument is interpreted as a sequence of objects
        objs = objs[0]
    objs = list(objs)
    for candidate in objs:
        if not checks.is_instance_of(candidate, BaseAccessor):
            raise TypeError("Each object to be merged must be an instance of BaseAccessor")

    if wrapper_kwargs is None:
        wrapper_kwargs = {}
    new_wrapper = kwargs.get("wrapper")
    if new_wrapper is None:
        stacked_wrappers = [stacked.wrapper for stacked in objs]
        new_wrapper = ArrayWrapper.row_stack(*stacked_wrappers, **wrapper_kwargs)
    elif len(wrapper_kwargs) > 0:
        new_wrapper = new_wrapper.replace(**wrapper_kwargs)
    kwargs["wrapper"] = new_wrapper

    kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    if kwargs["wrapper"].ndim == 1:
        accessor_cls = cls.sr_accessor_cls
    else:
        accessor_cls = cls.df_accessor_cls
    return accessor_cls(**kwargs)
def __init__(
    self,
    wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
    obj: tp.Optional[tp.ArrayLike] = None,
    **kwargs,
) -> None:
    # Two calling conventions are supported: (wrapper, obj) or (obj,) alone.
    # Wrapper-related keyword arguments are separated from the remaining ones first.
    wrapper_kwargs = {}
    if len(kwargs) > 0:
        wrapper_kwargs, kwargs = ArrayWrapper.extract_init_kwargs(**kwargs)

    if isinstance(wrapper, ArrayWrapper):
        if obj is None:
            raise ValueError("Must either provide wrapper and object, or only object")
        if len(wrapper_kwargs) > 0:
            wrapper = wrapper.replace(**wrapper_kwargs)
    else:
        # The first argument is the object itself: derive the wrapper from it
        if obj is not None:
            raise ValueError("Must either provide wrapper and object, or only object")
        obj = wrapper
        wrapper = ArrayWrapper.from_obj(obj, **wrapper_kwargs)

    Wrapping.__init__(self, wrapper, obj=obj, **kwargs)

    self._obj = obj
@property
def obj(self) -> tp.SeriesFrame:
    """Pandas object.

    The stored object is returned as-is if it is already a Series/DataFrame that has the
    wrapper's shape, shares the wrapper's index object, and either has the wrapper's name
    (Series) or shares the wrapper's columns object (DataFrame). Otherwise, it is wrapped
    using the wrapper with grouping disabled."""
    stored = self._obj
    wrapper = self.wrapper
    if (
        isinstance(stored, (pd.Series, pd.DataFrame))
        and stored.shape == wrapper.shape
        and stored.index is wrapper.index
    ):
        if isinstance(stored, pd.Series):
            matches_wrapper = stored.name == wrapper.name
        else:
            # Identity check: the columns must be the very same index object
            matches_wrapper = stored.columns is wrapper.columns
        if matches_wrapper:
            return stored
    return wrapper.wrap(stored, group_by=False)
1 -> Series, 2 -> DataFrame.""" if isinstance(cls_or_self, type): return None return cls_or_self.obj.ndim @hybrid_method def is_series(cls_or_self) -> bool: """Whether the object is a Series.""" if isinstance(cls_or_self, type): raise NotImplementedError return isinstance(cls_or_self.obj, pd.Series) @hybrid_method def is_frame(cls_or_self) -> bool: """Whether the object is a DataFrame.""" if isinstance(cls_or_self, type): raise NotImplementedError return isinstance(cls_or_self.obj, pd.DataFrame) @classmethod def resolve_shape(cls, shape: tp.ShapeLike) -> tp.Shape: """Resolve shape.""" shape_2d = reshaping.to_2d_shape(shape) try: if cls.is_series() and shape_2d[1] > 1: raise ValueError("Use DataFrame accessor") except NotImplementedError: pass return shape_2d # ############# Creation ############# # @classmethod def empty(cls, shape: tp.Shape, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame: """Generate an empty Series/DataFrame of shape `shape` and fill with `fill_value`.""" if not isinstance(shape, tuple) or (isinstance(shape, tuple) and len(shape) == 1): return pd.Series(np.full(shape, fill_value), **kwargs) return pd.DataFrame(np.full(shape, fill_value), **kwargs) @classmethod def empty_like(cls, other: tp.SeriesFrame, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame: """Generate an empty Series/DataFrame like `other` and fill with `fill_value`.""" if checks.is_series(other): return cls.empty(other.shape, fill_value=fill_value, index=other.index, name=other.name, **kwargs) return cls.empty(other.shape, fill_value=fill_value, index=other.index, columns=other.columns, **kwargs) # ############# Indexes ############# # def apply_to_index( self: BaseAccessorT, *args, wrap: bool = False, **kwargs, ) -> tp.Union[BaseAccessorT, tp.SeriesFrame]: """See `vectorbtpro.base.wrapping.Wrapping.apply_to_index`. !!! 
note If `wrap` is False, returns Pandas object, not accessor!""" result = Wrapping.apply_to_index(self, *args, **kwargs) if wrap: return result return result.obj # ############# Setting ############# # def set( self, value_or_func: tp.Union[tp.MaybeArray, tp.Callable], *args, inplace: bool = False, columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None, template_context: tp.KwargsLike = None, **kwargs, ) -> tp.Optional[tp.SeriesFrame]: """Set value at each index point using `vectorbtpro.base.indexing.get_index_points`. If `value_or_func` is a function, selects all keyword arguments that were not passed to the `get_index_points` method, substitutes any templates, and passes everything to the function. As context uses `kwargs`, `template_context`, and various variables such as `i` (iteration index), `index_point` (absolute position in the index), `wrapper`, and `obj`.""" if inplace: obj = self.obj else: obj = self.obj.copy() index_points = get_index_points(self.wrapper.index, **kwargs) if callable(value_or_func): func_kwargs = {k: v for k, v in kwargs.items() if k not in point_idxr_defaults} template_context = merge_dicts(kwargs, template_context) else: func_kwargs = None if callable(value_or_func): for i in range(len(index_points)): _template_context = merge_dicts( dict( i=i, index_point=index_points[i], index_points=index_points, wrapper=self.wrapper, obj=self.obj, columns=columns, args=args, kwargs=kwargs, ), template_context, ) _func_args = substitute_templates(args, _template_context, eval_id="func_args") _func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs") v = value_or_func(*_func_args, **_func_kwargs) if self.is_series() or columns is None: obj.iloc[index_points[i]] = v elif is_scalar(columns): obj.iloc[index_points[i], obj.columns.get_indexer([columns])[0]] = v else: obj.iloc[index_points[i], obj.columns.get_indexer(columns)] = v elif checks.is_sequence(value_or_func) and not is_scalar(value_or_func): if 
def set_between(
    self,
    value_or_func: tp.Union[tp.MaybeArray, tp.Callable],
    *args,
    inplace: bool = False,
    columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> tp.Optional[tp.SeriesFrame]:
    """Set value at each index range using `vectorbtpro.base.indexing.get_index_ranges`.

    All keyword arguments are passed to `get_index_ranges` to resolve the ranges.

    If `value_or_func` is a function, selects all keyword arguments whose names are not in
    `range_idxr_defaults`, substitutes any templates, and passes everything to the function.
    As context uses `kwargs`, `template_context`, and various variables such as `i` (iteration index),
    `index_slice` (absolute slice of the index), `range_starts`, `range_ends`, `wrapper`, `obj`,
    `columns`, `args`, and `kwargs`.

    If `value_or_func` is a sequence (but not a string), its i-th element is set at the i-th range.
    Otherwise, the same value is set at each range.

    Use `columns` to set the value in one (scalar) or multiple columns of a DataFrame only.

    Returns None if `inplace` is True, otherwise the modified copy."""
    if inplace:
        obj = self.obj
    else:
        obj = self.obj.copy()
    index_ranges = get_index_ranges(self.wrapper.index, **kwargs)
    range_starts, range_ends = index_ranges[0], index_ranges[1]
    n_ranges = len(range_starts)

    is_func = callable(value_or_func)
    if is_func:
        func_kwargs = {k: v for k, v in kwargs.items() if k not in range_idxr_defaults}
        template_context = merge_dicts(kwargs, template_context)
    else:
        func_kwargs = None
    # Whether each range takes its own element of `value_or_func`
    value_per_range = (
        not is_func and checks.is_sequence(value_or_func) and not isinstance(value_or_func, str)
    )

    # Column positions are the same for every range, so resolve them once.
    # Skipped when there is nothing to set to avoid touching the columns needlessly.
    if n_ranges == 0 or self.is_series() or columns is None:
        col_idxs = None
    elif is_scalar(columns):
        col_idxs = obj.columns.get_indexer([columns])[0]
    else:
        col_idxs = obj.columns.get_indexer(columns)

    for i in range(n_ranges):
        index_slice = slice(range_starts[i], range_ends[i])
        if is_func:
            _template_context = merge_dicts(
                dict(
                    i=i,
                    index_slice=index_slice,
                    range_starts=range_starts,
                    range_ends=range_ends,
                    wrapper=self.wrapper,
                    obj=self.obj,
                    columns=columns,
                    args=args,
                    kwargs=kwargs,
                ),
                template_context,
            )
            _func_args = substitute_templates(args, _template_context, eval_id="func_args")
            _func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs")
            v = value_or_func(*_func_args, **_func_kwargs)
        elif value_per_range:
            v = value_or_func[i]
        else:
            v = value_or_func
        if col_idxs is None:
            obj.iloc[index_slice] = v
        else:
            obj.iloc[index_slice, col_idxs] = v
    if inplace:
        return None
    return obj
def repeat(
    self,
    n: int,
    keys: tp.Optional[tp.IndexLike] = None,
    axis: int = 1,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.base.reshaping.repeat`.

    Set `axis` to 1 for columns and 0 for index.

    If `keys` is provided, the labels along `axis` are rebuilt by combining the existing labels
    with `keys` using `vectorbtpro.base.indexes.combine_indexes`. The existing labels are passed
    first and `keys` second, so `keys` follow the existing levels rather than preceding them.
    `wrap_kwargs` are used only in this case."""
    repeated = reshaping.repeat(self.obj, n, axis=axis)
    if keys is None:
        return repeated

    if axis == 1:
        new_labels = dict(columns=indexes.combine_indexes([self.wrapper.columns, keys]))
    else:
        new_labels = dict(index=indexes.combine_indexes([self.wrapper.index, keys]))
    repeated_wrapper = ArrayWrapper.from_obj(repeated)
    return repeated_wrapper.wrap(repeated.values, **merge_dicts(new_labels, wrap_kwargs))
@hybrid_method
def align(
    cls_or_self,
    *others: tp.Union[tp.SeriesFrame, "BaseAccessor"],
    **kwargs,
) -> tp.Tuple[tp.SeriesFrame, ...]:
    """Align objects using `vectorbtpro.base.indexes.align_indexes`.

    Accessors among `others` are replaced by their Pandas objects. When called on an instance,
    the instance's object comes first. Both the indexes and the columns are aligned, each with
    `**kwargs` passed to `align_indexes`.

    An object that was one-dimensional and still has a single column after alignment is
    returned as a Series with its original name. Any other object is returned as a DataFrame
    with the new columns. All objects receive the new index."""
    others = tuple(other.obj if isinstance(other, BaseAccessor) else other for other in others)
    if isinstance(cls_or_self, type):
        objs = others
    else:
        objs = (cls_or_self.obj, *others)
    objs_2d = [reshaping.to_2d(obj) for obj in objs]
    index_slices, new_index = indexes.align_indexes(
        *(obj_2d.index for obj_2d in objs_2d),
        return_new_index=True,
        **kwargs,
    )
    column_slices, new_columns = indexes.align_indexes(
        *(obj_2d.columns for obj_2d in objs_2d),
        return_new_index=True,
        **kwargs,
    )
    new_objs = []
    for i, (obj, obj_2d) in enumerate(zip(objs, objs_2d)):
        new_obj = obj_2d.iloc[index_slices[i], column_slices[i]].copy(deep=False)
        if obj.ndim == 1 and new_obj.shape[1] == 1:
            new_obj = new_obj.iloc[:, 0].rename(obj.name)
            # A Series has no columns: assigning `.columns` would only create a stray attribute
            # (or overwrite an element labeled "columns"), so only the index is replaced
            new_obj.index = new_index
        else:
            new_obj.index = new_index
            new_obj.columns = new_columns
        new_objs.append(new_obj)
    return tuple(new_objs)
@hybrid_method
def cross(cls_or_self, *others: tp.Union[tp.SeriesFrame, "BaseAccessor"]) -> tp.Tuple[tp.SeriesFrame, ...]:
    """Align objects using `vectorbtpro.base.indexes.cross_indexes`.

    Accessors among `others` are replaced by their Pandas objects. When called on an instance,
    the instance's object comes first. Both the indexes and the columns are crossed.

    An object that was one-dimensional and still has a single column after crossing is
    returned as a Series with its original name. Any other object is returned as a DataFrame
    with the new columns. All objects receive the new index."""
    others = tuple(other.obj if isinstance(other, BaseAccessor) else other for other in others)
    if isinstance(cls_or_self, type):
        objs = others
    else:
        objs = (cls_or_self.obj, *others)
    objs_2d = [reshaping.to_2d(obj) for obj in objs]
    index_slices, new_index = indexes.cross_indexes(
        *(obj_2d.index for obj_2d in objs_2d),
        return_new_index=True,
    )
    column_slices, new_columns = indexes.cross_indexes(
        *(obj_2d.columns for obj_2d in objs_2d),
        return_new_index=True,
    )
    new_objs = []
    for i, (obj, obj_2d) in enumerate(zip(objs, objs_2d)):
        new_obj = obj_2d.iloc[index_slices[i], column_slices[i]].copy(deep=False)
        if obj.ndim == 1 and new_obj.shape[1] == 1:
            new_obj = new_obj.iloc[:, 0].rename(obj.name)
            # A Series has no columns: assigning `.columns` would only create a stray attribute
            # (or overwrite an element labeled "columns"), so only the index is replaced
            new_obj.index = new_index
        else:
            new_obj.index = new_index
            new_obj.columns = new_columns
        new_objs.append(new_obj)
    return tuple(new_objs)
def apply(
    self,
    apply_func: tp.Callable,
    *args,
    keep_pd: bool = False,
    to_2d: bool = False,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Apply a function `apply_func` to the object and wrap the result.

    The object is stored under the name `obj` and, if `broadcast_named_args` contains any other
    arguments, broadcast together with them using `vectorbtpro.base.reshaping.broadcast`
    (which receives `broadcast_kwargs`); the wrapper returned by broadcasting is then used
    for wrapping the output.

    Set `keep_pd` to True to keep inputs as pandas objects, otherwise convert to NumPy arrays.
    Set `to_2d` to True to reshape inputs to 2-dim arrays, otherwise keep as-is.

    Templates in `*args` and `**kwargs` are substituted using the broadcast arguments merged
    with `template_context`, and then passed to `apply_func` after the object.

    !!! note
        The resulting array must have the same shape as the original array.

    Usage:
        * Using a function:

        ```pycon
        >>> sr = pd.Series([1, 2], index=['x', 'y'])
        >>> sr.vbt.apply(lambda x: x ** 2)
        x    1
        y    4
        dtype: int64
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> sr.vbt.apply(
        ...     lambda x, y: x + y,
        ...     vbt.Rep('y'),
        ...     broadcast_named_args=dict(
        ...         y=pd.DataFrame([[3, 4]], columns=['a', 'b'])
        ...     )
        ... )
           a  b
        x  4  5
        y  5  6
        ```
    """
    broadcast_named_args = {} if broadcast_named_args is None else broadcast_named_args
    broadcast_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    template_context = {} if template_context is None else template_context

    named_objs = {"obj": self.obj, **broadcast_named_args}
    if len(named_objs) > 1:
        named_objs, wrapper = reshaping.broadcast(
            named_objs,
            return_wrapper=True,
            **broadcast_kwargs,
        )
    else:
        # Nothing to broadcast against: the existing wrapper stays valid
        wrapper = self.wrapper

    if to_2d:
        named_objs = {name: reshaping.to_2d(value, raw=not keep_pd) for name, value in named_objs.items()}
    elif not keep_pd:
        named_objs = {name: np.asarray(value) for name, value in named_objs.items()}

    template_context = merge_dicts(named_objs, template_context)
    func_args = substitute_templates(args, template_context, eval_id="args")
    func_kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
    result = apply_func(named_objs["obj"], *func_args, **func_kwargs)
    return wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs))
def apply_and_concat(
    self,
    ntimes: int,
    apply_func: tp.Callable,
    *args,
    keep_pd: bool = False,
    to_2d: bool = False,
    keys: tp.Optional[tp.IndexLike] = None,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeTuple[tp.Frame]:
    """Apply `apply_func` `ntimes` times and concatenate the results along columns.

    See `vectorbtpro.base.combining.apply_and_concat`.

    The object is stored under the name `obj` and, if `broadcast_named_args` contains any other
    arguments, broadcast together with them using `vectorbtpro.base.reshaping.broadcast`.
    Set `keep_pd` to True to keep inputs as pandas objects, otherwise convert to NumPy arrays.
    Set `to_2d` to True to reshape inputs to 2-dim arrays, otherwise keep as-is.

    Templates in `*args` and `**kwargs` are substituted using the broadcast arguments, `ntimes`,
    and `template_context`. Afterwards, `*args` and `**kwargs` are passed to
    `vectorbtpro.base.combining.apply_and_concat` after `ntimes`, `apply_func`, and the object.

    Use `keys` as the outermost column level. If None, a level named `apply_idx` that ranges
    from 0 to `ntimes` (exclusive) is used.

    Returns None if there is no output, a tuple of wrapped objects if the output is a list,
    and a single wrapped object otherwise.

    !!! note
        The resulting arrays to be concatenated must have the same shape as broadcast input arrays.

    Usage:
        * Using keys:

        ```pycon
        >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
        >>> df.vbt.apply_and_concat(
        ...     3,
        ...     lambda i, a, b: a * b[i],
        ...     [1, 2, 3],
        ...     keys=['c', 'd', 'e']
        ... )
              c       d       e
           a  b   a   b   a   b
        x  3  4   6   8   9  12
        y  5  6  10  12  15  18
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
        >>> sr.vbt.apply_and_concat(
        ...     3,
        ...     lambda i, a, b: a * b + i,
        ...     vbt.Rep('df'),
        ...     broadcast_named_args=dict(
        ...         df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
        ...     )
        ... )
        apply_idx        0         1         2
                   a  b  c  a  b   c  a  b   c
        x          1  2  3  2  3   4  3  4   5
        y          2  4  6  3  5   7  4  6   8
        z          3  6  9  4  7  10  5  8  11
        ```

        * To change the execution engine or specify other engine-related arguments,
        use `execute_kwargs`:

        ```pycon
        >>> import time

        >>> def apply_func(i, a):
        ...     time.sleep(1)
        ...     return a

        >>> sr = pd.Series([1, 2, 3])

        >>> %timeit sr.vbt.apply_and_concat(3, apply_func)
        3.02 s ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

        >>> %timeit sr.vbt.apply_and_concat(3, apply_func, execute_kwargs=dict(engine='dask'))
        1.02 s ± 927 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
        ```
    """
    broadcast_named_args = {} if broadcast_named_args is None else broadcast_named_args
    broadcast_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    template_context = {} if template_context is None else template_context

    named_objs = {"obj": self.obj, **broadcast_named_args}
    if len(named_objs) > 1:
        named_objs, wrapper = reshaping.broadcast(
            named_objs,
            return_wrapper=True,
            **broadcast_kwargs,
        )
    else:
        # Nothing to broadcast against: the existing wrapper stays valid
        wrapper = self.wrapper

    if to_2d:
        named_objs = {name: reshaping.to_2d(value, raw=not keep_pd) for name, value in named_objs.items()}
    elif not keep_pd:
        named_objs = {name: np.asarray(value) for name, value in named_objs.items()}

    template_context = merge_dicts(named_objs, dict(ntimes=ntimes), template_context)
    func_args = substitute_templates(args, template_context, eval_id="args")
    func_kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
    out = combining.apply_and_concat(ntimes, apply_func, named_objs["obj"], *func_args, **func_kwargs)

    # Build the column hierarchy with the keys (or the default level) on top
    if keys is None:
        outer_level = pd.Index(np.arange(ntimes), name="apply_idx")
    else:
        outer_level = keys
    new_columns = indexes.combine_indexes([outer_level, wrapper.columns])

    if out is None:
        return None
    wrap_kwargs = merge_dicts(dict(columns=new_columns), wrap_kwargs)
    if isinstance(out, list):
        return tuple(wrapper.wrap(arr, group_by=False, **wrap_kwargs) for arr in out)
    return wrapper.wrap(out, group_by=False, **wrap_kwargs)
broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`. template_context (dict): Context used to substitute templates in `args` and `kwargs`. wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`. **kwargs: Keyword arguments passed to `combine_func`. !!! note If `combine_func` is Numba-compiled, will broadcast using `WRITEABLE` and `C_CONTIGUOUS` flags, which can lead to an expensive computation overhead if passed objects are large and have different shape/memory order. You also must ensure that all objects have the same data type. Also remember to bring each in `*args` to a Numba-compatible format. Usage: * Using instance method: ```pycon >>> sr = pd.Series([1, 2], index=['x', 'y']) >>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b']) >>> # using instance method >>> sr.vbt.combine(df, np.add) a b x 4 5 y 7 8 >>> sr.vbt.combine([df, df * 2], np.add, concat=False) a b x 10 13 y 17 20 >>> sr.vbt.combine([df, df * 2], np.add) combine_idx 0 1 a b a b x 4 5 7 9 y 7 8 12 14 >>> sr.vbt.combine([df, df * 2], np.add, keys=['c', 'd']) c d a b a b x 4 5 7 9 y 7 8 12 14 >>> sr.vbt.combine(vbt.Param([1, 2], name='param'), np.add) param 1 2 x 2 3 y 3 4 >>> # using class method >>> sr.vbt.combine([df, df * 2], np.add, concat=False) a b x 10 13 y 17 20 ``` * Using class method, templates, and broadcasting: ```pycon >>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z']) >>> sr.vbt.combine( ... [1, 2, 3], ... lambda x, y, z: x + y + z, ... vbt.Rep('df'), ... broadcast_named_args=dict( ... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) ... ) ... ) combine_idx 0 1 2 a b c a b c a b c x 3 4 5 4 5 6 5 6 7 y 4 5 6 5 6 7 6 7 8 z 5 6 7 6 7 8 7 8 9 ``` * To change the execution engine or specify other engine-related arguments, use `execute_kwargs`: ```pycon >>> import time >>> def combine_func(a, b): ... time.sleep(1) ... 
return a + b >>> sr = pd.Series([1, 2, 3]) >>> %timeit sr.vbt.combine([1, 1, 1], combine_func) 3.01 s ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) >>> %timeit sr.vbt.combine([1, 1, 1], combine_func, execute_kwargs=dict(engine='dask')) 1.02 s ± 2.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` """ from vectorbtpro.indicators.factory import IndicatorBase if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} if isinstance(cls_or_self, type): objs = obj else: if allow_multiple and isinstance(obj, (tuple, list)): objs = obj if concat is None: concat = True else: objs = (obj,) new_objs = [] for obj in objs: if isinstance(obj, BaseAccessor): obj = obj.obj elif isinstance(obj, IndicatorBase): obj = obj.main_output new_objs.append(obj) objs = tuple(new_objs) if not isinstance(cls_or_self, type): objs = (cls_or_self.obj,) + objs if checks.is_numba_func(combine_func): # Numba requires writeable arrays and in the same order broadcast_kwargs = merge_dicts(dict(require_kwargs=dict(requirements=["W", "C"])), broadcast_kwargs) # Broadcast and substitute templates broadcast_named_args = {**{"obj_" + str(i): obj for i, obj in enumerate(objs)}, **broadcast_named_args} broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) if to_2d: broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()} elif not keep_pd: broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()} template_context = merge_dicts(broadcast_named_args, template_context) args = substitute_templates(args, template_context, eval_id="args") kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs") inputs = [broadcast_named_args["obj_" + str(i)] for i in range(len(objs))] if concat is None: concat = len(inputs) > 2 if concat: # Concat the 
results horizontally if isinstance(cls_or_self, type): raise TypeError("Use instance method to concatenate") out = combining.combine_and_concat(inputs[0], inputs[1:], combine_func, *args, **kwargs) if keys is not None: new_columns = indexes.combine_indexes([keys, wrapper.columns]) else: top_columns = pd.Index(np.arange(len(objs) - 1), name="combine_idx") new_columns = indexes.combine_indexes([top_columns, wrapper.columns]) return wrapper.wrap(out, **merge_dicts(dict(columns=new_columns, force_2d=True), wrap_kwargs)) else: # Combine arguments pairwise into one object out = combining.combine_multiple(inputs, combine_func, *args, **kwargs) return wrapper.wrap(out, **resolve_dict(wrap_kwargs)) @classmethod def eval( cls, expr: str, frames_back: int = 1, use_numexpr: bool = False, numexpr_kwargs: tp.KwargsLike = None, local_dict: tp.Optional[tp.Mapping] = None, global_dict: tp.Optional[tp.Mapping] = None, broadcast_kwargs: tp.KwargsLike = None, wrap_kwargs: tp.KwargsLike = None, ): """Evaluate a simple array expression element-wise using NumExpr or NumPy. If NumExpr is enables, only one-line statements are supported. Otherwise, uses `vectorbtpro.utils.eval_.evaluate`. !!! note All required variables will broadcast against each other prior to the evaluation. 
class BaseSRAccessor(BaseAccessor):
    """Accessor on top of Series.

    Accessible via `pd.Series.vbt` and all child accessors."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # NOTE(review): nesting reconstructed from collapsed source; the base initializer
        # appears to run only when `_full_init` is True - confirm against the original file
        if _full_init:
            if isinstance(wrapper, ArrayWrapper):
                if wrapper.ndim == 2:
                    if wrapper.shape[1] == 1:
                        # A single-column 2-dim wrapper can be safely treated as a Series
                        wrapper = wrapper.replace(ndim=1)
                    else:
                        raise TypeError("Series accessors work only on one-dimensional data")

            BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)

    @hybrid_property
    def ndim(cls_or_self) -> int:
        """Number of dimensions, always 1."""
        return 1

    @hybrid_method
    def is_series(cls_or_self) -> bool:
        """Return True since this accessor wraps a Series."""
        return True

    @hybrid_method
    def is_frame(cls_or_self) -> bool:
        """Return False since this accessor does not wrap a DataFrame."""
        return False


class BaseDFAccessor(BaseAccessor):
    """Accessor on top of DataFrames.

    Accessible via `pd.DataFrame.vbt` and all child accessors."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # NOTE(review): nesting reconstructed from collapsed source; the base initializer
        # appears to run only when `_full_init` is True - confirm against the original file
        if _full_init:
            if isinstance(wrapper, ArrayWrapper):
                if wrapper.ndim == 1:
                    # Promote a 1-dim wrapper so that it describes a single-column DataFrame
                    wrapper = wrapper.replace(ndim=2)

            BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)

    @hybrid_property
    def ndim(cls_or_self) -> int:
        """Number of dimensions, always 2."""
        return 2

    @hybrid_method
    def is_series(cls_or_self) -> bool:
        """Return False since this accessor does not wrap a Series."""
        return False

    @hybrid_method
    def is_frame(cls_or_self) -> bool:
        """Return True since this accessor wraps a DataFrame."""
        return True


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Extensions for chunking of base operations."""

import uuid

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.chunking import (
    ArgGetter,
    ArgSizer,
    ArraySizer,
    ChunkMeta,
    ChunkMapper,
    ChunkSlicer,
    ShapeSlicer,
    ArraySelector,
    ArraySlicer,
    Chunked,
)
from vectorbtpro.utils.parsing import Regex

__all__ = [
    "GroupLensSizer",
    "GroupLensSlicer",
    "ChunkedGroupLens",
    "GroupLensMapper",
    "GroupMapSlicer",
    "ChunkedGroupMap",
    "GroupIdxsMapper",
    "FlexArraySizer",
    "FlexArraySelector",
    "FlexArraySlicer",
    "ChunkedFlexArray",
    "shape_gl_slicer",
    "flex_1d_array_gl_slicer",
    "flex_array_gl_slicer",
    "array_gl_slicer",
]


class GroupLensSizer(ArgSizer):
    """Class for getting the size from group lengths.

    Argument can be either a group map tuple or a group lengths array."""

    @classmethod
    def get_obj_size(cls, obj: tp.Union[tp.GroupLens, tp.GroupMap], single_type: tp.Optional[type] = None) -> int:
        """Get size of an object.

        Returns 1 if `obj` is an instance of `single_type`. For a tuple (group map),
        returns the length of its second element, otherwise the length of `obj` itself."""
        if single_type is not None:
            if checks.is_instance_of(obj, single_type):
                return 1
        if isinstance(obj, tuple):
            return len(obj[1])
        return len(obj)

    def get_size(self, ann_args: tp.AnnArgs, **kwargs) -> int:
        """Get the size of the argument resolved from `ann_args`."""
        return self.get_obj_size(self.get_arg(ann_args), single_type=self.single_type)


class GroupLensSlicer(ChunkSlicer):
    """Class for slicing multiple elements from group lengths based on the chunk range."""

    def get_size(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], **kwargs) -> int:
        """Get the number of groups in `obj`."""
        return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)

    def take(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], chunk_meta: ChunkMeta, **kwargs) -> tp.GroupLens:
        """Take the group lengths between `chunk_meta.start` and `chunk_meta.end`.

        If `obj` is a tuple (group map), its second element is sliced."""
        if isinstance(obj, tuple):
            return obj[1][chunk_meta.start : chunk_meta.end]
        return obj[chunk_meta.start : chunk_meta.end]


class ChunkedGroupLens(Chunked):
    """Class representing chunkable group lengths."""

    def resolve_take_spec(self) -> tp.TakeSpec:
        """Return `GroupLensSlicer` if no specification was provided.

        Raises:
            ValueError: If selection is requested, as only slicing is supported."""
        if self.take_spec_missing:
            if self.select:
                raise ValueError("Selection is not supported")
            return GroupLensSlicer
        return self.take_spec


def get_group_lens_slice(group_lens: tp.GroupLens, chunk_meta: ChunkMeta) -> slice:
    """Get slice of each chunk in group lengths.

    The slice starts at the sum of the lengths of all groups that come before
    `chunk_meta.start` and spans the lengths of the groups in the chunk.
    An empty chunk results in an empty slice."""
    group_lens = np.asarray(group_lens)
    # Sum the two ranges separately rather than index into a cumulative sum,
    # which fails with an IndexError whenever the chunk contains no groups
    start = np.sum(group_lens[: chunk_meta.start])
    end = start + np.sum(group_lens[chunk_meta.start : chunk_meta.end])
    return slice(start, end)


@define
class GroupLensMapper(ChunkMapper, ArgGetter, DefineMixin):
    """Class for mapping chunk metadata to per-group column lengths.

    Argument can be either a group map tuple or a group lengths array."""

    def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
        """Map a range of groups to the range produced by `get_group_lens_slice`.

        The returned metadata keeps the chunk index but gets a new UUID."""
        group_lens = self.get_arg(ann_args)
        if isinstance(group_lens, tuple):
            group_lens = group_lens[1]
        group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
        return ChunkMeta(
            uuid=str(uuid.uuid4()),
            idx=chunk_meta.idx,
            start=group_lens_slice.start,
            end=group_lens_slice.stop,
            indices=None,
        )


group_lens_mapper = GroupLensMapper(arg_query=Regex(r"(group_lens|group_map)"))
"""Default instance of `GroupLensMapper`."""


class GroupMapSlicer(ChunkSlicer):
    """Class for slicing multiple elements from a group map based on the chunk range."""

    def get_size(self, obj: tp.GroupMap, **kwargs) -> int:
        """Get the number of groups in `obj`."""
        return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)

    def take(self, obj: tp.GroupMap, chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap:
        """Take the groups between `chunk_meta.start` and `chunk_meta.end`.

        Returns the sliced group lengths together with freshly enumerated indices
        (zero up to the sum of the sliced lengths); the original indices are not used."""
        _, group_lens = obj
        group_lens = group_lens[chunk_meta.start : chunk_meta.end]
        return np.arange(np.sum(group_lens)), group_lens


class ChunkedGroupMap(Chunked):
    """Class representing a chunkable group map."""

    def resolve_take_spec(self) -> tp.TakeSpec:
        """Return `GroupMapSlicer` if no specification was provided.

        Raises:
            ValueError: If selection is requested, as only slicing is supported."""
        if self.take_spec_missing:
            if self.select:
                raise ValueError("Selection is not supported")
            return GroupMapSlicer
        return self.take_spec


@define
class GroupIdxsMapper(ChunkMapper, ArgGetter, DefineMixin):
    """Class for mapping chunk metadata to per-group column indices.

    Argument must be a group map tuple."""

    def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
        """Map a range of groups to the indices stored in the group map for that range.

        The returned metadata has no start and end, only `indices`."""
        group_map = self.get_arg(ann_args)
        group_idxs, group_lens = group_map
        group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
        return ChunkMeta(
            uuid=str(uuid.uuid4()),
            idx=chunk_meta.idx,
            start=None,
            end=None,
            indices=group_idxs[group_lens_slice],
        )


group_idxs_mapper = GroupIdxsMapper(arg_query="group_map")
"""Default instance of `GroupIdxsMapper`."""


class FlexArraySizer(ArraySizer):
    """Class for getting the size from the length of an axis in a flexible array."""

    @classmethod
    def get_obj_size(cls, obj: tp.AnyArray, axis: int, single_type: tp.Optional[type] = None) -> int:
        """Get size of an object.

        Returns 1 for an instance of `single_type`, for a zero-dimensional array, and for
        a one-dimensional array queried along axis 1. `axis` may be None only for
        one-dimensional arrays, in which case it becomes 0.

        Raises:
            ValueError: If the array has more than two dimensions."""
        if single_type is not None:
            if checks.is_instance_of(obj, single_type):
                return 1
        obj = np.asarray(obj)
        if len(obj.shape) == 0:
            return 1
        if axis is None:
            if len(obj.shape) == 1:
                axis = 0
        checks.assert_not_none(axis, arg_name="axis")
        checks.assert_in(axis, (0, 1), arg_name="axis")
        if len(obj.shape) == 1:
            if axis == 1:
                return 1
            return obj.shape[0]
        if len(obj.shape) == 2:
            if axis == 1:
                return obj.shape[1]
            return obj.shape[0]
        raise ValueError(f"FlexArraySizer supports max 2 dimensions, not {len(obj.shape)}")


@define
class FlexArraySelector(ArraySelector, DefineMixin):
    """Class for selecting one element from a NumPy array's axis flexibly based on the chunk index.

    The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
    and `vectorbtpro.base.flex_indexing.flex_select_nb`."""

    def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
        """Get the size of `obj` along the axis using `FlexArraySizer.get_obj_size`."""
        return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)

    def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
        """Return None, that is, never suggest a size."""
        return None

    def take(
        self,
        obj: tp.ArrayLike,
        chunk_meta: ChunkMeta,
        ann_args: tp.Optional[tp.AnnArgs] = None,
        **kwargs,
    ) -> tp.ArrayLike:
        """Select the element at `chunk_meta.idx` along the axis.

        Scalars, zero-dimensional arrays, and arrays with only one element along the
        axis are returned as-is. If `keep_dims` is True, the selected axis is kept with length 1.

        Raises:
            ValueError: If the array has more than two dimensions."""
        if np.isscalar(obj):
            return obj
        obj = np.asarray(obj)
        if len(obj.shape) == 0:
            return obj
        axis = self.axis
        if axis is None:
            if len(obj.shape) == 1:
                axis = 0
        checks.assert_not_none(axis, arg_name="axis")
        checks.assert_in(axis, (0, 1), arg_name="axis")
        if len(obj.shape) == 1:
            if axis == 1 or obj.shape[0] == 1:
                return obj
            if self.keep_dims:
                return obj[chunk_meta.idx : chunk_meta.idx + 1]
            return obj[chunk_meta.idx]
        if len(obj.shape) == 2:
            if axis == 1:
                if obj.shape[1] == 1:
                    return obj
                if self.keep_dims:
                    return obj[:, chunk_meta.idx : chunk_meta.idx + 1]
                return obj[:, chunk_meta.idx]
            if obj.shape[0] == 1:
                return obj
            if self.keep_dims:
                return obj[chunk_meta.idx : chunk_meta.idx + 1, :]
            return obj[chunk_meta.idx, :]
        raise ValueError(f"FlexArraySelector supports max 2 dimensions, not {len(obj.shape)}")


@define
class FlexArraySlicer(ArraySlicer, DefineMixin):
    """Class for slicing multiple elements from a NumPy array's axis flexibly based on the chunk range.

    The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
    and `vectorbtpro.base.flex_indexing.flex_select_nb`."""

    def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
        """Get the size of `obj` along the axis using `FlexArraySizer.get_obj_size`."""
        return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)

    def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
        """Return None, that is, never suggest a size."""
        return None

    def take(
        self,
        obj: tp.ArrayLike,
        chunk_meta: ChunkMeta,
        ann_args: tp.Optional[tp.AnnArgs] = None,
        **kwargs,
    ) -> tp.ArrayLike:
        """Slice the elements between `chunk_meta.start` and `chunk_meta.end` along the axis.

        Scalars, zero-dimensional arrays, and arrays with only one element along the
        axis are returned as-is.

        Raises:
            ValueError: If the array has more than two dimensions."""
        if np.isscalar(obj):
            return obj
        obj = np.asarray(obj)
        if len(obj.shape) == 0:
            return obj
        axis = self.axis
        if axis is None:
            if len(obj.shape) == 1:
                axis = 0
        checks.assert_not_none(axis, arg_name="axis")
        checks.assert_in(axis, (0, 1), arg_name="axis")
        if len(obj.shape) == 1:
            if axis == 1 or obj.shape[0] == 1:
                return obj
            return obj[chunk_meta.start : chunk_meta.end]
        if len(obj.shape) == 2:
            if axis == 1:
                if obj.shape[1] == 1:
                    return obj
                return obj[:, chunk_meta.start : chunk_meta.end]
            if obj.shape[0] == 1:
                return obj
            return obj[chunk_meta.start : chunk_meta.end, :]
        raise ValueError(f"FlexArraySlicer supports max 2 dimensions, not {len(obj.shape)}")


class ChunkedFlexArray(Chunked):
    """Class representing a chunkable flexible array."""

    def resolve_take_spec(self) -> tp.TakeSpec:
        """Return `FlexArraySelector` or `FlexArraySlicer` if no specification was provided."""
        if self.take_spec_missing:
            if self.select:
                return FlexArraySelector
            return FlexArraySlicer
        return self.take_spec


shape_gl_slicer = ShapeSlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim shape slicer along the column axis based on group lengths."""

flex_1d_array_gl_slicer = FlexArraySlicer(mapper=group_lens_mapper)
"""Flexible 1-dim array slicer along the column axis based on group lengths."""

flex_array_gl_slicer = FlexArraySlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim array slicer along the column axis based on group lengths."""

array_gl_slicer = ArraySlicer(axis=1, mapper=group_lens_mapper)
"""2-dim array slicer along the column axis based on group lengths."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Functions for combining arrays.

Combine functions combine two or more NumPy arrays using a custom function. The emphasis here is
done upon stacking the results into one NumPy array - since vectorbt is all about brute-forcing
large spaces of hyper-parameters, concatenating the results of each hyper-parameter combination
into a single DataFrame is important. All functions are available in both Python and Numba-compiled form."""

import numpy as np
from numba.typed import List

from vectorbtpro import _typing as tp
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.template import RepFunc

__all__ = []


@register_jitted
def custom_apply_and_concat_none_nb(
    indices: tp.Array1d,
    apply_func_nb: tp.Callable,
    *args,
) -> None:
    """Run `apply_func_nb` that returns nothing for each index.

    Meant for in-place outputs."""
    for i in indices:
        apply_func_nb(i, *args)


@register_jitted
def apply_and_concat_none_nb(
    ntimes: int,
    apply_func_nb: tp.Callable,
    *args,
) -> None:
    """Run `apply_func_nb` that returns nothing number of times.

    Uses `custom_apply_and_concat_none_nb`."""
    custom_apply_and_concat_none_nb(np.arange(ntimes), apply_func_nb, *args)


@register_jitted
def to_2d_one_nb(a: tp.Array) -> tp.Array2d:
    """Expand the dimensions of the array along the axis 1.

    Arrays with more than one dimension are returned as-is."""
    if a.ndim > 1:
        return a
    return np.expand_dims(a, axis=1)


@register_jitted
def custom_apply_and_concat_one_nb(
    indices: tp.Array1d,
    apply_func_nb: tp.Callable,
    *args,
) -> tp.Array2d:
    """Run `apply_func_nb` that returns one array for each index.

    The output of the first index determines the shape and data type of the
    concatenated array; outputs are written side by side along the column axis."""
    output_0 = to_2d_one_nb(apply_func_nb(indices[0], *args))
    output = np.empty((output_0.shape[0], len(indices) * output_0.shape[1]), dtype=output_0.dtype)
    for i in range(len(indices)):
        if i == 0:
            # Reuse the first output instead of calling the function twice
            outputs_i = output_0
        else:
            outputs_i = to_2d_one_nb(apply_func_nb(indices[i], *args))
        output[:, i * outputs_i.shape[1] : (i + 1) * outputs_i.shape[1]] = outputs_i
    return output


@register_jitted
def apply_and_concat_one_nb(
    ntimes: int,
    apply_func_nb: tp.Callable,
    *args,
) -> tp.Array2d:
    """Run `apply_func_nb` that returns one array number of times.

    Uses `custom_apply_and_concat_one_nb`."""
    return custom_apply_and_concat_one_nb(np.arange(ntimes), apply_func_nb, *args)


@register_jitted
def to_2d_multiple_nb(a: tp.Iterable[tp.Array]) -> tp.List[tp.Array2d]:
    """Expand the dimensions of each array in `a` along axis 1."""
    lst = list()
    for _a in a:
        lst.append(to_2d_one_nb(_a))
    return lst


@register_jitted
def custom_apply_and_concat_multiple_nb(
    indices: tp.Array1d,
    apply_func_nb: tp.Callable,
    *args,
) -> tp.List[tp.Array2d]:
    """Run `apply_func_nb` that returns multiple arrays for each index.

    One concatenated array is allocated per output of the first index; the
    outputs of each index are then written side by side along the column axis."""
    outputs = list()
    outputs_0 = to_2d_multiple_nb(apply_func_nb(indices[0], *args))
    for j in range(len(outputs_0)):
        outputs.append(
            np.empty((outputs_0[j].shape[0], len(indices) * outputs_0[j].shape[1]), dtype=outputs_0[j].dtype)
        )
    for i in range(len(indices)):
        if i == 0:
            # Reuse the first output instead of calling the function twice
            outputs_i = outputs_0
        else:
            outputs_i = to_2d_multiple_nb(apply_func_nb(indices[i], *args))
        for j in range(len(outputs_i)):
            outputs[j][:, i * outputs_i[j].shape[1] : (i + 1) * outputs_i[j].shape[1]] = outputs_i[j]
    return outputs


@register_jitted
def apply_and_concat_multiple_nb(
    ntimes: int,
    apply_func_nb: tp.Callable,
    *args,
) -> tp.List[tp.Array2d]:
    """Run `apply_func_nb` that returns multiple arrays number of times.

    Uses `custom_apply_and_concat_multiple_nb`."""
    return custom_apply_and_concat_multiple_nb(np.arange(ntimes), apply_func_nb, *args)


def apply_and_concat_each(
    tasks: tp.TasksLike,
    n_outputs: tp.Optional[int] = None,
    execute_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
    """Apply each function on its own set of positional and keyword arguments.

    Executes the function using `vectorbtpro.utils.execution.execute`.

    If `n_outputs` is None, it gets derived from the first result: 0 if it is None,
    its length if it is a tuple/list, and 1 otherwise. Returns None for zero outputs,
    one column-stacked array for one output, and a list of such arrays otherwise."""
    from vectorbtpro.base.merging import column_stack_arrays

    if execute_kwargs is None:
        execute_kwargs = {}

    out = execute(tasks, **execute_kwargs)
    if n_outputs is None:
        if out[0] is None:
            n_outputs = 0
        elif isinstance(out[0], (tuple, list, List)):
            n_outputs = len(out[0])
        else:
            n_outputs = 1
    if n_outputs == 0:
        return None
    if n_outputs == 1:
        if isinstance(out[0], (tuple, list, List)) and len(out[0]) == 1:
            # Unpack single-element sequences before stacking
            out = list(map(lambda x: x[0], out))
        return column_stack_arrays(out)
    return list(map(column_stack_arrays, zip(*out)))


def apply_and_concat(
    ntimes: int,
    apply_func: tp.Callable,
    *args,
    n_outputs: tp.Optional[int] = None,
    jitted_loop: bool = False,
    jitted_warmup: bool = False,
    execute_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
    """Run `apply_func` function a number of times and concatenate the results depending upon
    how many array-like objects it generates.

    `apply_func` must accept arguments `i`, `*args`, and `**kwargs`.

    Set `jitted_loop` to True to use the JIT-compiled version. If `jitted_warmup` is also True,
    the jitted function is first called once on the index 0.

    All jitted iteration functions are resolved using
    `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.

    The key `size` of `execute_kwargs` is always set to `ntimes`; the passed dictionary
    itself is left untouched.

    Raises:
        ValueError: If `jitted_loop` is True and `n_outputs` is None, or if a chunk
            has neither indices nor both a start and an end.

    !!! note
        `n_outputs` must be set when `jitted_loop` is True.

        Numba doesn't support variable keyword arguments."""
    if jitted_loop:
        if n_outputs is None:
            raise ValueError("Jitted iteration requires n_outputs")
        if n_outputs == 0:
            func = jit_reg.resolve(custom_apply_and_concat_none_nb)
        elif n_outputs == 1:
            func = jit_reg.resolve(custom_apply_and_concat_one_nb)
        else:
            func = jit_reg.resolve(custom_apply_and_concat_multiple_nb)
        if jitted_warmup:
            func(np.array([0]), apply_func, *args, **kwargs)

        def _tasks_template(chunk_meta):
            # Build one task per chunk, each running the jitted loop over the chunk's indices
            tasks = []
            for _chunk_meta in chunk_meta:
                if _chunk_meta.indices is not None:
                    chunk_indices = np.asarray(_chunk_meta.indices)
                else:
                    if _chunk_meta.start is None or _chunk_meta.end is None:
                        raise ValueError("Each chunk must have a start and an end index")
                    chunk_indices = np.arange(_chunk_meta.start, _chunk_meta.end)
                tasks.append(Task(func, chunk_indices, apply_func, *args, **kwargs))
            return tasks

        tasks = RepFunc(_tasks_template)
    else:
        tasks = [(apply_func, (i, *args), kwargs) for i in range(ntimes)]

    # Copy before setting the size so that the caller's dictionary is not mutated
    execute_kwargs = {} if execute_kwargs is None else dict(execute_kwargs)
    execute_kwargs["size"] = ntimes
    return apply_and_concat_each(
        tasks,
        n_outputs=n_outputs,
        execute_kwargs=execute_kwargs,
    )


@register_jitted
def select_and_combine_nb(
    i: int,
    obj: tp.Any,
    others: tp.Sequence,
    combine_func_nb: tp.Callable,
    *args,
) -> tp.AnyArray:
    """Numba-compiled version of `select_and_combine`."""
    return combine_func_nb(obj, others[i], *args)


@register_jitted
def combine_and_concat_nb(
    obj: tp.Any,
    others: tp.Sequence,
    combine_func_nb: tp.Callable,
    *args,
) -> tp.Array2d:
    """Numba-compiled version of `combine_and_concat`."""
    return apply_and_concat_one_nb(len(others), select_and_combine_nb, obj, others, combine_func_nb, *args)


def select_and_combine(
    i: int,
    obj: tp.Any,
    others: tp.Sequence,
    combine_func: tp.Callable,
    *args,
    **kwargs,
) -> tp.AnyArray:
    """Combine `obj` with an array at position `i` in `others` using `combine_func`."""
    return combine_func(obj, others[i], *args, **kwargs)


def combine_and_concat(
    obj: tp.Any,
    others: tp.Sequence,
    combine_func: tp.Callable,
    *args,
    jitted_loop: bool = False,
    **kwargs,
) -> tp.Array2d:
    """Combine `obj` with each in `others` using `combine_func` and concatenate.

    Runs `apply_and_concat` with one output per element of `others`; remaining
    keyword arguments are forwarded to it.

    `select_and_combine_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`."""
    if jitted_loop:
        apply_func = jit_reg.resolve(select_and_combine_nb)
    else:
        apply_func = select_and_combine
    return apply_and_concat(
        len(others),
        apply_func,
        obj,
        others,
        combine_func,
        *args,
        n_outputs=1,
        jitted_loop=jitted_loop,
        **kwargs,
    )


@register_jitted
def combine_multiple_nb(
    objs: tp.Sequence,
    combine_func_nb: tp.Callable,
    *args,
) -> tp.Any:
    """Numba-compiled version of `combine_multiple`."""
    result = objs[0]
    for i in range(1, len(objs)):
        result = combine_func_nb(result, objs[i], *args)
    return result


def combine_multiple(
    objs: tp.Sequence,
    combine_func: tp.Callable,
    *args,
    jitted_loop: bool = False,
    **kwargs,
) -> tp.Any:
    """Combine `objs` pairwise into a single object.

    Starts with the first object and folds every following object into the
    running result from left to right.

    Set `jitted_loop` to True to use the JIT-compiled version, in which case
    `**kwargs` are not passed to `combine_func`.

    `combine_multiple_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.

    !!! note
        Numba doesn't support variable keyword arguments."""
    if jitted_loop:
        func = jit_reg.resolve(combine_multiple_nb)
        return func(objs, combine_func, *args)
    result = objs[0]
    for i in range(1, len(objs)):
        result = combine_func(result, objs[i], *args, **kwargs)
    return result


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Class decorators for base classes."""

from functools import cached_property as cachedproperty

from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts

__all__ = []


def override_arg_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper:
    """Class decorator to override the argument config of a class subclassing
    `vectorbtpro.base.preparing.BasePreparer`.

    Rather than redefining the `_arg_config` class attribute by hand, hand `config` to this
    decorator. With `merge_configs` enabled, `config` is merged on top of the class's current
    `arg_config`; with it disabled, `config` replaces it entirely, which effectively
    disables field inheritance."""

    def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
        checks.assert_subclass_of(cls, "BasePreparer")
        resolved = merge_dicts(cls.arg_config, config) if merge_configs else config
        if not isinstance(resolved, Config):
            # Plain dictionaries are promoted to a hybrid config
            resolved = HybridConfig(resolved)
        setattr(cls, "_arg_config", resolved)
        return cls

    return wrapper


def _install_cached_property(cls: type, func: tp.Callable, name: str, doc: str) -> None:
    """Rename `func` to `name`, document it, and attach it to `cls` as a cached property."""
    func.__name__ = name
    func.__module__ = cls.__module__
    func.__qualname__ = f"{cls.__name__}.{name}"
    func.__doc__ = doc
    setattr(cls, name, cachedproperty(func))
    # The descriptor is set after class creation, so its name must be registered manually
    getattr(cls, name).__set_name__(cls, name)


def attach_arg_properties(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
    """Class decorator to attach properties for arguments defined in the argument config
    of a `vectorbtpro.base.preparing.BasePreparer` subclass.

    For an argument that is broadcast or has its templates substituted (and whose `attach`
    setting is True or missing), three cached properties are attached: `_pre_<name>` reading
    the raw argument, `_post_<name>` reading the prepared argument, and `<name>` forwarding
    to the latter. For any other argument whose `attach` setting is missing, a single cached
    property `<name>` reading the raw argument is attached. Attributes that already exist
    on the class are never overwritten."""
    checks.assert_subclass_of(cls, "BasePreparer")

    for arg_name, settings in cls.arg_config.items():
        attach = settings.get("attach", None)
        broadcast = settings.get("broadcast", False)
        substitute_templates = settings.get("substitute_templates", False)
        is_processed = broadcast or substitute_templates

        if attach is True or (attach is None and is_processed):
            return_type = tp.ArrayLike if broadcast else object
            if broadcast and substitute_templates:
                stage = "broadcasting and template substitution"
            elif broadcast:
                stage = "broadcasting"
            else:
                stage = "template substitution"

            pre_name = "_pre_" + arg_name
            if not hasattr(cls, pre_name):

                def pre_arg_prop(self, _arg_name: str = arg_name) -> return_type:
                    return self.get_arg(_arg_name)

                _install_cached_property(cls, pre_arg_prop, pre_name, f"Argument `{arg_name}` before {stage}.")

            post_name = "_post_" + arg_name
            if not hasattr(cls, post_name):

                def post_arg_prop(self, _arg_name: str = arg_name) -> return_type:
                    return self.prepare_post_arg(_arg_name)

                _install_cached_property(cls, post_arg_prop, post_name, f"Argument `{arg_name}` after {stage}.")

            if not hasattr(cls, arg_name):

                def arg_prop(self, _target_post_name: str = post_name) -> return_type:
                    return getattr(self, _target_post_name)

                _install_cached_property(cls, arg_prop, arg_name, f"Argument `{arg_name}`.")
        elif attach is None:
            # An explicit True is always handled by the branch above
            if not hasattr(cls, arg_name):

                def arg_prop(self, _arg_name: str = arg_name) -> tp.Any:
                    return self.get_arg(_arg_name)

                _install_cached_property(cls, arg_prop, arg_name, f"Argument `{arg_name}`.")
    return cls
setattr(cls, arg_prop.__name__, cachedproperty(arg_prop)) getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__) elif (isinstance(attach, bool) and attach) or attach is None: if not hasattr(cls, arg_name): def arg_prop(self, _arg_name: str = arg_name) -> tp.Any: return self.get_arg(_arg_name) arg_prop.__name__ = arg_name arg_prop.__module__ = cls.__module__ arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}" arg_prop.__doc__ = f"Argument `{arg_name}`." setattr(cls, arg_prop.__name__, cachedproperty(arg_prop)) getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__) return cls # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Classes and functions for flexible indexing.""" from vectorbtpro import _typing as tp from vectorbtpro._settings import settings from vectorbtpro.registries.jit_registry import register_jitted __all__ = [ "flex_select_1d_nb", "flex_select_1d_pr_nb", "flex_select_1d_pc_nb", "flex_select_nb", "flex_select_row_nb", "flex_select_col_nb", "flex_select_2d_row_nb", "flex_select_2d_col_nb", ] _rotate_rows = settings["indexing"]["rotate_rows"] _rotate_cols = settings["indexing"]["rotate_cols"] @register_jitted(cache=True) def flex_choose_i_1d_nb(arr: tp.FlexArray1d, i: int) -> int: """Choose a position in an array as if it has been broadcast against rows or columns. !!! 
@register_jitted(cache=True)
def flex_choose_i_pr_1d_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> int:
    """Resolve the position to read from a flexible array that is broadcast against rows.

    An array with a single element always resolves to position 0. Otherwise, position `i`
    is used, wrapped around the length of the array if `rotate_rows` is True.

    !!! note
        Array must be one-dimensional."""
    n_rows = arr.shape[0]
    row_pos = 0 if n_rows == 1 else i
    if not rotate_rows:
        return int(row_pos)
    # Rotational indexing: positions beyond the array start over from the beginning
    return int(row_pos) % n_rows
@register_jitted(cache=True)
def flex_choose_i_and_col_nb(
    arr: tp.FlexArray2d,
    i: int,
    col: int,
    rotate_rows: bool = _rotate_rows,
    rotate_cols: bool = _rotate_cols,
) -> tp.Tuple[int, int]:
    """Resolve the row and column position to read from a flexible array that is
    broadcast against rows and columns.

    An axis of length one always resolves to position 0. Otherwise, the requested position
    is used, wrapped around the axis length if rotational indexing is enabled for that axis.

    !!! note
        Array must be two-dimensional."""
    n_rows = arr.shape[0]
    n_cols = arr.shape[1]
    row_pos = int(0 if n_rows == 1 else i)
    col_pos = int(0 if n_cols == 1 else col)
    # Each axis is wrapped independently of the other
    if rotate_rows:
        row_pos = row_pos % n_rows
    if rotate_cols:
        col_pos = col_pos % n_cols
    return row_pos, col_pos
@register_jitted(cache=True)
def flex_select_2d_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array2d:
    """Pick one column out of a flexible 2-dim array while keeping both dimensions.

    The column position is resolved with `flex_choose_i_pc_nb`, and the result
    is a slice of exactly one column.

    !!! note
        Array must be two-dimensional."""
    col_start = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols)
    col_stop = col_start + 1
    # Slicing (instead of integer indexing) preserves the column axis
    return arr[:, col_start:col_stop]
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Functions for working with indexes: index and columns. They perform operations on index objects, such as stacking, combining, and cleansing MultiIndex levels. !!! note "Index" in pandas context is referred to both index and columns.""" from datetime import datetime, timedelta import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.registries.jit_registry import jit_reg, register_jitted from vectorbtpro.utils import checks from vectorbtpro.utils.attr_ import DefineMixin, define from vectorbtpro.utils.base import Base __all__ = [ "ExceptLevel", "repeat_index", "tile_index", "stack_indexes", "combine_indexes", ] @define class ExceptLevel(DefineMixin): """Class for grouping except one or more levels.""" value: tp.MaybeLevelSequence = define.field() """One or more level positions or names.""" def to_any_index(index_like: tp.IndexLike) -> tp.Index: """Convert any index-like object to an index. 
def index_from_values(
    values: tp.Sequence,
    single_value: bool = False,
    name: tp.Optional[tp.Hashable] = None,
) -> tp.Index:
    """Create a new `pd.Index` with `name` by parsing an iterable `values`.

    Each in `values` will correspond to an element in the new index:

    * None and scalars (including NumPy scalars, datetimes, and timedeltas) are used as-is
    * NumPy arrays whose elements are all the same are represented by that element
    * Other NumPy arrays (including empty and structured ones) become "array_{number}"
    * Any other object becomes "{type name}_{number}"

    Numbers are assigned per type in order of appearance, and the same object (by identity)
    always receives the same number.

    If `single_value` is True or there is only one value, only the first value is parsed,
    labeled without a number, and repeated for each value in `values`."""
    scalar_types = (int, float, complex, str, bool, datetime, timedelta, np.generic)
    type_id_number = {}

    def _numbered_label(type_name: str, obj: tp.Any) -> str:
        # Objects are told apart by identity; each new object gets the next number of its type
        id_numbers = type_id_number.setdefault(type_name, {})
        number = id_numbers.setdefault(id(obj), len(id_numbers))
        return f"{type_name}_{number}"

    if len(values) == 1:
        single_value = True
    # With a single value, only the first element needs to be parsed
    n_parsed = min(len(values), 1) if single_value else len(values)

    value_names = []
    for i in range(n_parsed):
        v = values[i]
        if v is None or isinstance(v, scalar_types):
            value_names.append(v)
            continue
        if isinstance(v, np.ndarray):
            if v.size == 0 or v.dtype.names is not None:
                # Empty arrays have no first element to compare against,
                # and structured arrays are never collapsed into a single element
                all_same = False
            elif np.issubdtype(v.dtype, np.floating):
                all_same = bool(np.isclose(v, v.item(0), equal_nan=True).all())
            else:
                all_same = bool(np.equal(v, v.item(0)).all())
            if all_same:
                value_names.append(v.item(0))
                continue
            type_name = "array"
        else:
            type_name = str(type(v).__name__)
        if single_value:
            value_names.append(type_name)
        else:
            value_names.append(_numbered_label(type_name, v))

    if single_value and len(values) > 1:
        value_names *= len(values)
    return pd.Index(value_names, name=name)
def stack_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **clean_index_kwargs) -> tp.Index:
    """Stack each index in `indexes` on top of each other, from top to bottom.

    Multi-indexes contribute each of their levels, any other index-like object contributes
    a single level. Levels of length one are repeated to match the longest level.

    Keyword arguments are passed to `clean_index`, which is applied on the stacked index."""
    if len(indexes) == 1:
        indexes = indexes[0]

    levels = []
    for index in list(indexes):
        if isinstance(index, pd.MultiIndex):
            levels.extend(index.get_level_values(level_pos) for level_pos in range(index.nlevels))
        else:
            levels.append(to_any_index(index))

    max_len = max(len(level) for level in levels)
    for i, level in enumerate(levels):
        if len(level) >= max_len:
            continue
        # Only levels with a single element can be stretched to the common length
        if len(level) != 1:
            raise ValueError(f"Index at level {i} could not be broadcast to shape ({max_len},) ")
        levels[i] = repeat_index(level, max_len, ignore_ranges=False)
    stacked_index = pd.MultiIndex.from_arrays(levels)
    return clean_index(stacked_index, **clean_index_kwargs)
def concat_indexes(
    *indexes: tp.MaybeTuple[tp.IndexLike],
    index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
    keys: tp.Optional[tp.IndexLike] = None,
    clean_index_kwargs: tp.KwargsLike = None,
    verify_integrity: bool = True,
    axis: int = 1,
) -> tp.Index:
    """Concatenate indexes.

    The following index concatenation methods are supported:

    * 'append': append one index to another
    * 'union': build a union of indexes
    * 'pd_concat': convert indexes to Pandas Series or DataFrames and use `pd.concat`
    * 'factorize': factorize the concatenated index
    * 'factorize_each': factorize each index and concatenate while keeping numbers unique
    * 'reset': reset the concatenated index without applying `keys`
    * Callable: a custom callable that takes the indexes and returns the concatenated index

    Argument `index_concat_method` also accepts a tuple of two options: the second option gets applied
    if the first one fails.

    Use `keys` as an index with the same number of elements as there are indexes to add
    another index level on top of the concatenated indexes. If `keys` is a sequence of indexes,
    they are appended to each other and stacked on top of the concatenated index as-is.

    If `keys` is None and all indexes are default ranges, returns a new range index.

    If `verify_integrity` is True and `keys` is None, performs various checks depending on the axis:

    * 0: index must be monotonically increasing, not mixed, and without duplicates
    * 1: columns must have no duplicates
    * 2: groups of different indexes must not overlap

    Raises `ValueError` if the concatenation method is unknown, if `keys` is combined with 'union',
    or if an integrity check fails."""
    if len(indexes) == 1:
        indexes = indexes[0]
    indexes = list(indexes)
    if clean_index_kwargs is None:
        clean_index_kwargs = {}
    if axis == 0:
        factorized_name = "row_idx"
    elif axis == 1:
        factorized_name = "col_idx"
    else:
        factorized_name = "group_idx"

    if keys is None and all(checks.is_default_index(index) for index in indexes):
        # Nothing but simple ranges: the result is a simple range as well
        return pd.RangeIndex(stop=sum(map(len, indexes)))

    if isinstance(index_concat_method, tuple):
        # Indexes are passed as one list (not unpacked) such that a single index
        # isn't mistaken for a sequence of indexes by the nested call
        nested_kwargs = dict(
            keys=keys,
            clean_index_kwargs=clean_index_kwargs,
            verify_integrity=verify_integrity,
            axis=axis,
        )
        try:
            return concat_indexes(indexes, index_concat_method=index_concat_method[0], **nested_kwargs)
        except Exception:
            # Deliberately broad: any failure of the first method triggers the second one
            return concat_indexes(indexes, index_concat_method=index_concat_method[1], **nested_kwargs)

    method = index_concat_method.lower() if isinstance(index_concat_method, str) else None
    if method is None:
        new_index = index_concat_method(indexes)
    elif method == "append":
        new_index = None
        for index in indexes:
            if new_index is None:
                new_index = index
            else:
                new_index = new_index.append(index)
    elif method == "union":
        if keys is not None:
            raise ValueError("Cannot apply keys after concatenating indexes through union")
        new_index = None
        for index in indexes:
            if new_index is None:
                new_index = index
            else:
                new_index = new_index.union(index)
    elif method == "pd_concat":
        new_index = None
        for index in indexes:
            if isinstance(index, pd.MultiIndex):
                index = index.to_frame().reset_index(drop=True)
            else:
                index = index.to_series().reset_index(drop=True)
            if new_index is None:
                new_index = index
            else:
                # Bring both sides to the same dimensionality before concatenating:
                # `to_frame` keeps the name of the Series as the column name, so it can be
                # aligned with the level of the same name on the other side
                if isinstance(new_index, pd.DataFrame) and isinstance(index, pd.Series):
                    index = index.to_frame()
                elif isinstance(new_index, pd.Series) and isinstance(index, pd.DataFrame):
                    new_index = new_index.to_frame()
                new_index = pd.concat((new_index, index), ignore_index=True)
        if isinstance(new_index, pd.Series):
            new_index = pd.Index(new_index)
        else:
            new_index = pd.MultiIndex.from_frame(new_index)
    elif method == "factorize":
        new_index = concat_indexes(
            indexes,
            index_concat_method="append",
            clean_index_kwargs=clean_index_kwargs,
            verify_integrity=False,
            axis=axis,
        )
        new_index = pd.Index(pd.factorize(new_index)[0], name=factorized_name)
    elif method == "factorize_each":
        new_index = None
        for index in indexes:
            index = pd.Index(pd.factorize(index)[0], name=factorized_name)
            if new_index is None:
                new_index = index
                next_min = index.max() + 1
            else:
                # Shift the codes such that they don't collide with the previous indexes
                new_index = new_index.append(index + next_min)
                next_min = index.max() + 1 + next_min
    elif method == "reset":
        return pd.RangeIndex(stop=sum(map(len, indexes)))
    else:
        if axis == 0:
            raise ValueError(f"Invalid index concatenation method: '{index_concat_method}'")
        elif axis == 1:
            raise ValueError(f"Invalid column concatenation method: '{index_concat_method}'")
        else:
            raise ValueError(f"Invalid group concatenation method: '{index_concat_method}'")

    if keys is not None:
        if isinstance(keys[0], pd.Index):
            # Keys are indexes themselves: append them and stack on top as-is
            keys = concat_indexes(
                list(keys),
                index_concat_method="append",
                clean_index_kwargs=clean_index_kwargs,
                verify_integrity=False,
                axis=axis,
            )
            new_index = stack_indexes((keys, new_index), **clean_index_kwargs)
            keys = None
        elif not isinstance(keys, pd.Index):
            keys = pd.Index(keys)
    if keys is not None:
        # One key per index: repeat each key to the length of its index
        top_index = None
        for i, index in enumerate(indexes):
            repeated_index = repeat_index(keys[[i]], len(index))
            if top_index is None:
                top_index = repeated_index
            else:
                top_index = top_index.append(repeated_index)
        new_index = stack_indexes((top_index, new_index), **clean_index_kwargs)

    if verify_integrity and keys is None:
        if axis == 0:
            if not new_index.is_monotonic_increasing:
                raise ValueError("Concatenated index is not monotonically increasing")
            if "mixed" in new_index.inferred_type:
                raise ValueError("Concatenated index is mixed")
            if new_index.has_duplicates:
                raise ValueError("Concatenated index contains duplicates")
        if axis == 1:
            if new_index.has_duplicates:
                raise ValueError("Concatenated columns contain duplicates")
        if axis == 2:
            if new_index.has_duplicates:
                # Duplicates are fine within one index but not across different indexes
                len_sum = 0
                for index in indexes:
                    if len_sum > 0:
                        prev_index = new_index[:len_sum]
                        this_index = new_index[len_sum : len_sum + len(index)]
                        if len(prev_index.intersection(this_index)) > 0:
                            raise ValueError("Concatenated groups contain duplicates")
                    len_sum += len(index)
    return new_index
Provide `levels` as an instance of `ExceptLevel` to drop everything apart from the specified levels.""" if not isinstance(index, pd.MultiIndex): if strict: raise TypeError("Index must be a multi-index") return index if isinstance(levels, ExceptLevel): levels = levels.value except_mode = True else: except_mode = False levels_to_drop = set() if isinstance(levels, str) or not checks.is_sequence(levels): levels = [levels] for level in levels: if level in index.names: for level_pos in [i for i, x in enumerate(index.names) if x == level]: levels_to_drop.add(level_pos) elif checks.is_int(level): if level < 0: new_level = index.nlevels + level if new_level < 0: raise KeyError(f"Level at position {level} not found") level = new_level if 0 <= level < index.nlevels: levels_to_drop.add(level) else: raise KeyError(f"Level at position {level} not found") elif strict: raise KeyError(f"Level '{level}' not found") if except_mode: levels_to_drop = set(range(index.nlevels)).difference(levels_to_drop) if len(levels_to_drop) == 0: if strict: raise ValueError("No levels to drop") return index if len(levels_to_drop) >= index.nlevels: if strict: raise ValueError( f"Cannot remove {len(levels_to_drop)} levels from an index with {index.nlevels} levels: " "at least one level must be left" ) return index return index.droplevel(list(levels_to_drop)) def rename_levels(index: tp.Index, mapper: tp.MaybeMappingSequence[tp.Level], strict: bool = True) -> tp.Index: """Rename levels in `index` by `mapper`. 
def select_levels(
    index: tp.Index,
    levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
    strict: bool = True,
) -> tp.Index:
    """Build a new index by selecting one or multiple `levels` from `index`.

    Levels can be provided by name or position (negative positions count from the end).
    A name that occurs multiple times selects all levels with that name. Selected levels
    appear in the order they were requested.

    Provide `levels` as an instance of `ExceptLevel` to select everything apart from the specified levels.
    In this case, the remaining levels keep their original order.

    If a single level was requested (not as a sequence) and resolves to one level,
    returns a regular index, otherwise a multi-index.

    Raises `KeyError` if a position is out of bounds, or if a name cannot be found and `strict` is True.
    Raises `ValueError` if there are no levels to select and `strict` is True; if `strict` is False,
    returns the index as-is."""
    was_multiindex = isinstance(index, pd.MultiIndex)
    if not was_multiindex:
        # Treat a regular index as a multi-index with one level
        index = pd.MultiIndex.from_arrays([index])
    except_mode = isinstance(levels, ExceptLevel)
    if except_mode:
        levels = levels.value
    single_mode = isinstance(levels, str) or not checks.is_sequence(levels)
    if single_mode:
        levels = [levels]

    levels_to_select = []
    for level in levels:
        if level in index.names:
            # Names take precedence over positions
            for level_pos, level_name in enumerate(index.names):
                if level_name == level and level_pos not in levels_to_select:
                    levels_to_select.append(level_pos)
        elif checks.is_int(level):
            if level < 0:
                new_level = index.nlevels + level
                if new_level < 0:
                    raise KeyError(f"Level at position {level} not found")
                level = new_level
            if 0 <= level < index.nlevels:
                if level not in levels_to_select:
                    levels_to_select.append(level)
            else:
                raise KeyError(f"Level at position {level} not found")
        elif strict:
            raise KeyError(f"Level '{level}' not found")

    if except_mode:
        # Build the complement in ascending level order explicitly
        # rather than relying on the iteration order of a set
        levels_to_select = [level_pos for level_pos in range(index.nlevels) if level_pos not in levels_to_select]
    if len(levels_to_select) == 0:
        if strict:
            raise ValueError("No levels to select")
        if not was_multiindex:
            return index.get_level_values(0)
        return index
    if len(levels_to_select) == 1 and single_mode:
        return index.get_level_values(levels_to_select[0])
    levels = [index.get_level_values(level) for level in levels_to_select]
    return pd.MultiIndex.from_arrays(levels)
@register_jitted(cache=True)
def align_arr_indices_nb(a: tp.Array1d, b: tp.Array1d) -> tp.Array1d:
    """Return indices required to align `a` to `b`.

    For each element in `b`, finds the position of its first occurrence in `a`,
    such that `a[idxs]` equals `b`.

    Raises `ValueError` if an element in `b` cannot be found in `a`, since
    no position can be returned for it."""
    idxs = np.empty(b.shape[0], dtype=int_)
    for i in range(b.shape[0]):
        found = False
        for j in range(a.shape[0]):
            if b[i] == a[j]:
                idxs[i] = j
                found = True
                break
        if not found:
            # Without this check, the positions of the following elements would be shifted
            # and the end of the output array would be left uninitialized
            raise ValueError("Cannot align arrays: some elements of the second array are missing in the first")
    return idxs
Returns index slice for the aligning.""" if not isinstance(index1, pd.MultiIndex): index1 = pd.MultiIndex.from_arrays([index1]) if not isinstance(index2, pd.MultiIndex): index2 = pd.MultiIndex.from_arrays([index2]) if checks.is_index_equal(index1, index2): return pd.IndexSlice[:] if len(index1) > len(index2): raise ValueError("Longer index cannot be aligned to shorter index") mapper = {} for i in range(index1.nlevels): name1 = index1.names[i] for j in range(index2.nlevels): name2 = index2.names[j] if name1 is None or name2 is None or name1 == name2: if set(index2.levels[j]).issubset(set(index1.levels[i])): if i in mapper: raise ValueError(f"There are multiple candidate levels with name {name1} in second index") mapper[i] = j continue if name1 == name2 and name1 is not None: raise ValueError(f"Level {name1} in second index contains values not in first index") if len(mapper) == 0: if len(index1) == len(index2): return pd.IndexSlice[:] raise ValueError("Cannot find common levels to align indexes") factorized = [] for k, v in mapper.items(): factorized.append( pd.factorize( pd.concat( ( index1.get_level_values(k).to_series(), index2.get_level_values(v).to_series(), ) ) )[0], ) stacked = np.transpose(np.stack(factorized)) indices1 = stacked[: len(index1)] indices2 = stacked[len(index1) :] if len(indices1) < len(indices2): if len(np.unique(indices1, axis=0)) != len(indices1): raise ValueError("Cannot align indexes") if len(index2) % len(index1) == 0: tile_times = len(index2) // len(index1) index1_tiled = np.tile(indices1, (tile_times, 1)) if np.array_equal(index1_tiled, indices2): return pd.IndexSlice[np.tile(np.arange(len(index1)), tile_times)] unique_indices = np.unique(stacked, axis=0, return_inverse=True)[1] unique1 = unique_indices[: len(index1)] unique2 = unique_indices[len(index1) :] if len(indices1) == len(indices2): if np.array_equal(unique1, unique2): return pd.IndexSlice[:] func = jit_reg.resolve_option(align_arr_indices_nb, jitted) return 
@register_jitted(cache=True)
def block_index_product_nb(
    block_group_map1: tp.GroupMap,
    block_group_map2: tp.GroupMap,
    factorized1: tp.Array1d,
    factorized2: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Return indices required for building a block-wise Cartesian product of two factorized indexes.

    Each group map is a tuple of group indices and group lengths. Every group of the first map
    is crossed with every group of the second map, and a pair of positions is kept
    whenever both positions have the same factorized value.

    Returns two arrays of the same length with positions in the first and second index respectively.

    Raises `ValueError` if any position in either index remains without a match."""
    group_idxs1, group_lens1 = block_group_map1
    group_idxs2, group_lens2 = block_group_map2
    group_start_idxs1 = np.cumsum(group_lens1) - group_lens1
    group_start_idxs2 = np.cumsum(group_lens2) - group_lens2

    # Masks must start as all-False: uninitialized memory could pretend that
    # unmatched positions have been matched and defeat the check below
    matched1 = np.zeros(len(factorized1), dtype=np.bool_)
    matched2 = np.zeros(len(factorized2), dtype=np.bool_)
    # Upper bound on the number of pairs is the full Cartesian product
    indices1 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
    indices2 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
    k = 0
    for g1 in range(len(group_lens1)):
        group_len1 = group_lens1[g1]
        group_start1 = group_start_idxs1[g1]

        for g2 in range(len(group_lens2)):
            group_len2 = group_lens2[g2]
            group_start2 = group_start_idxs2[g2]

            for c1 in range(group_len1):
                i = group_idxs1[group_start1 + c1]

                for c2 in range(group_len2):
                    j = group_idxs2[group_start2 + c2]

                    if factorized1[i] == factorized2[j]:
                        matched1[i] = True
                        matched2[j] = True
                        indices1[k] = i
                        indices2[k] = j
                        k += 1
    if not np.all(matched1) or not np.all(matched2):
        raise ValueError("Cannot match some block level values")
    return indices1[:k], indices2[:k]
def cross_indexes(
    *indexes: tp.MaybeTuple[tp.Index],
    return_new_index: bool = False,
) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]:
    """Cross multiple indexes with `cross_index_with`.

    Two indexes are crossed directly. More indexes are crossed from right to left: the last
    two indexes first, then each preceding index with the product built so far.

    Returns one index slice per index, in the order of `indexes`. If `return_new_index` is True,
    also returns the crossed index."""
    if len(indexes) == 1:
        indexes = indexes[0]
    indexes = list(indexes)
    if len(indexes) == 2:
        return cross_index_with(indexes[0], indexes[1], return_new_index=return_new_index)

    first_step = len(indexes) - 2
    index = None
    index_slices = []
    for i in range(first_step, -1, -1):
        is_first_step = i == first_step
        right_index = indexes[i + 1] if is_first_step else index
        (left_slice, right_slice), index = cross_index_with(indexes[i], right_index, return_new_index=True)
        if is_first_step:
            index_slices.append(right_slice)
        else:
            # Slices collected so far refer to the previous product: re-map them to the new one
            for j, prev_slice in enumerate(index_slices):
                if isinstance(prev_slice, slice):
                    prev_slice = np.arange(len(right_index))[prev_slice]
                index_slices[j] = prev_slice[right_slice]
        index_slices.append(left_slice)

    # Slices were collected from the last index to the first one
    ordered_slices = tuple(reversed(index_slices))
    if return_new_index:
        return ordered_slices, index
    return ordered_slices
def find_first_occurrence(index_value: tp.Any, index: tp.Index) -> int:
    """Return the position of the first occurrence of `index_value` in `index`.

    The result of `index.get_loc` is reduced to a single position: the start of a slice,
    the first element of a list, the first truthy position of an array, or the value
    itself if it is none of those."""
    located = index.get_loc(index_value)
    if isinstance(located, np.ndarray):
        # Array result: take the first non-zero position
        return np.flatnonzero(located)[0]
    if isinstance(located, list):
        return located[0]
    if isinstance(located, slice):
        return located.start
    return located
See `IndexApplier.apply_to_index` for other keyword arguments.""" def _apply_func(index): return drop_levels(index, levels, strict=strict) return self.apply_to_index(_apply_func, **kwargs) def rename_levels( self: IndexApplierT, mapper: tp.MaybeMappingSequence[tp.Level], strict: bool = True, **kwargs, ) -> IndexApplierT: """Rename levels using `rename_levels`. See `IndexApplier.apply_to_index` for other keyword arguments.""" def _apply_func(index): return rename_levels(index, mapper, strict=strict) return self.apply_to_index(_apply_func, **kwargs) def select_levels( self: IndexApplierT, level_names: tp.Union[ExceptLevel, tp.MaybeLevelSequence], strict: bool = True, **kwargs, ) -> IndexApplierT: """Select levels using `select_levels`. See `IndexApplier.apply_to_index` for other keyword arguments.""" def _apply_func(index): return select_levels(index, level_names, strict=strict) return self.apply_to_index(_apply_func, **kwargs) def drop_redundant_levels(self: IndexApplierT, **kwargs) -> IndexApplierT: """Drop any redundant levels using `drop_redundant_levels`. See `IndexApplier.apply_to_index` for other keyword arguments.""" def _apply_func(index): return drop_redundant_levels(index) return self.apply_to_index(_apply_func, **kwargs) def drop_duplicate_levels(self: IndexApplierT, keep: tp.Optional[str] = None, **kwargs) -> IndexApplierT: """Drop any duplicate levels using `drop_duplicate_levels`. See `IndexApplier.apply_to_index` for other keyword arguments.""" def _apply_func(index): return drop_duplicate_levels(index, keep=keep) return self.apply_to_index(_apply_func, **kwargs) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Classes and functions for indexing.""" import functools from datetime import time from functools import partial import numpy as np import pandas as pd from pandas.tseries.offsets import BaseOffset from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks, datetime_ as dt, datetime_nb as dt_nb from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING from vectorbtpro.utils.base import Base from vectorbtpro.utils.config import hdict, merge_dicts from vectorbtpro.utils.mapping import to_field_mapping from vectorbtpro.utils.pickling import pdict from vectorbtpro.utils.selection import PosSel, LabelSel from vectorbtpro.utils.template import CustomTemplate __all__ = [ "PandasIndexer", "ExtPandasIndexer", "hslice", "get_index_points", "get_index_ranges", "get_idxs", "index_dict", "IdxSetter", "IdxSetterFactory", "IdxDict", "IdxSeries", "IdxFrame", "IdxRecords", "posidx", "maskidx", "lbidx", "dtidx", "dtcidx", "pointidx", "rangeidx", "autoidx", "rowidx", "colidx", "idx", ] __pdoc__ = {} class IndexingError(Exception): """Exception raised when an indexing error has occurred.""" IndexingBaseT = tp.TypeVar("IndexingBaseT", bound="IndexingBase") class IndexingBase(Base): """Class that supports indexing through `IndexingBase.indexing_func`.""" def indexing_func(self: IndexingBaseT, pd_indexing_func: tp.Callable, **kwargs) -> IndexingBaseT: """Apply `pd_indexing_func` on all pandas objects in question and return a new instance of the class. 
class LocBase(Base):
    """Base class for location-based indexers.

    Stores an indexing function, an optional indexing setter function, and keyword
    arguments for them. Item access and assignment must be implemented by subclasses."""

    def __init__(
        self,
        indexing_func: tp.Callable,
        indexing_setter_func: tp.Optional[tp.Callable] = None,
        **kwargs,
    ) -> None:
        self._indexing_kwargs = kwargs
        self._indexing_setter_func = indexing_setter_func
        self._indexing_func = indexing_func

    @property
    def indexing_func(self) -> tp.Callable:
        """Function that performs the indexing."""
        return self._indexing_func

    @property
    def indexing_setter_func(self) -> tp.Optional[tp.Callable]:
        """Function that performs the assignment, or None if none was provided."""
        return self._indexing_setter_func

    @property
    def indexing_kwargs(self) -> dict:
        """Keyword arguments captured by the constructor for `LocBase.indexing_func`."""
        return self._indexing_kwargs

    def __getitem__(self, key: tp.Any) -> tp.Any:
        raise NotImplementedError

    def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
        raise NotImplementedError

    def __iter__(self):
        # Defining `__getitem__` would otherwise make instances implicitly iterable
        cls_name = type(self).__name__
        raise TypeError(f"'{cls_name}' object is not iterable")
class PandasIndexer(IndexingBase):
    """Adds pandas-style selection via `iloc`, `loc`, `xs`, and `__getitem__`.

    Each selection is wrapped into a function that takes a pandas object and is passed
    to `PandasIndexer.indexing_func` along with the keyword arguments given to the
    constructor. Assignments are passed to `PandasIndexer.indexing_setter_func` likewise.

    Usage:
        ```pycon
        >>> from vectorbtpro import *
        >>> from vectorbtpro.base.indexing import PandasIndexer

        >>> class C(PandasIndexer):
        ...     def __init__(self, df1, df2):
        ...         self.df1 = df1
        ...         self.df2 = df2
        ...         super().__init__()
        ...
        ...     def indexing_func(self, pd_indexing_func):
        ...         return type(self)(
        ...             pd_indexing_func(self.df1),
        ...             pd_indexing_func(self.df2)
        ...         )

        >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
        >>> df2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
        >>> c = C(df1, df2)

        >>> c.iloc[:, 0]
        <__main__.C object at 0x1a1cacbbe0>

        >>> c.iloc[:, 0].df1
        0    1
        1    2
        Name: a, dtype: int64

        >>> c.iloc[:, 0].df2
        0    5
        1    6
        Name: a, dtype: int64
        ```
    """

    def __init__(self, **kwargs) -> None:
        self._indexing_kwargs = kwargs
        self._iloc = iLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
        self._loc = Loc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)

    @property
    def indexing_kwargs(self) -> dict:
        """Keyword arguments forwarded with every indexing operation."""
        return self._indexing_kwargs

    @property
    def iloc(self) -> iLoc:
        """Indexer for selection by integer position."""
        return self._iloc

    # Expose the documentation of the indexer class on the property
    iloc.__doc__ = iLoc.__doc__

    @property
    def loc(self) -> Loc:
        """Indexer for selection by label."""
        return self._loc

    # Expose the documentation of the indexer class on the property
    loc.__doc__ = Loc.__doc__

    def xs(self: PandasIndexerT, *args, **kwargs) -> PandasIndexerT:
        """Forwards `pd.Series.xs`/`pd.DataFrame.xs` operation to each
        Series/DataFrame and returns a new class instance."""
        return self.indexing_func(lambda x: x.xs(*args, **kwargs), **self.indexing_kwargs)

    def __getitem__(self: PandasIndexerT, key: tp.Any) -> PandasIndexerT:
        # Apply plain `obj[key]` to every pandas object
        return self.indexing_func(lambda x: x.__getitem__(key), **self.indexing_kwargs)

    def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
        # Apply plain `obj[key] = value` to every pandas object
        self.indexing_setter_func(lambda x: x.__setitem__(key, value), **self.indexing_kwargs)

    def __iter__(self):
        # Defining `__getitem__` would otherwise make instances implicitly iterable
        cls_name = type(self).__name__
        raise TypeError(f"'{cls_name}' object is not iterable")
class ParamLoc(LocBase):
    """Access a group of columns by parameter using `pd.Series.loc`.

    Uses `mapper` to establish link between columns and parameter values.

    A mapper of object data type is converted to strings on construction, and keys
    are encoded the same way before lookup (see `ParamLoc.encode_key`)."""

    @classmethod
    def encode_key(cls, key: tp.Any) -> str:
        """Encode key as a string.

        NumPy scalars (also inside tuples) are converted to Python scalars first."""
        if isinstance(key, tuple):
            return str(tuple(k.item() if isinstance(k, np.generic) else k for k in key))
        if isinstance(key, np.generic):
            return str(key.item())
        return str(key)

    def __init__(
        self,
        mapper: tp.Series,
        indexing_func: tp.Callable,
        indexing_setter_func: tp.Optional[tp.Callable] = None,
        level_name: tp.Level = None,
        **kwargs,
    ) -> None:
        checks.assert_instance_of(mapper, pd.Series)

        if mapper.dtype == "O":
            # The first element decides whether values are tuples;
            # an empty mapper has no first element, so fall through to plain conversion
            if len(mapper) > 0 and isinstance(mapper.iloc[0], tuple):
                mapper = mapper.apply(self.encode_key)
            else:
                mapper = mapper.astype(str)
        self._mapper = mapper
        self._level_name = level_name

        LocBase.__init__(self, indexing_func, indexing_setter_func=indexing_setter_func, **kwargs)

    @property
    def mapper(self) -> tp.Series:
        """Mapper."""
        return self._mapper

    @property
    def level_name(self) -> tp.Level:
        """Level name."""
        return self._level_name

    def get_idxs(self, key: tp.Any) -> tp.Array1d:
        """Get array of indices affected by this key.

        The key is looked up with `pd.Series.loc` in a Series that maps
        the mapper's values to their integer positions."""
        if self.mapper.dtype == "O":
            # Mapper values were stringified in the constructor, so keys must be too
            if isinstance(key, (slice, hslice)):
                start = self.encode_key(key.start) if key.start is not None else None
                stop = self.encode_key(key.stop) if key.stop is not None else None
                key = slice(start, stop, key.step)
            elif isinstance(key, (list, np.ndarray)):
                key = list(map(self.encode_key, key))
            else:
                key = self.encode_key(key)
        mapper = pd.Series(np.arange(len(self.mapper.index)), index=self.mapper.values)
        idxs = mapper.loc.__getitem__(key)
        if isinstance(idxs, pd.Series):
            idxs = idxs.values
        return idxs

    def __getitem__(self, key: tp.Any) -> tp.Any:
        idxs = self.get_idxs(key)
        is_multiple = isinstance(key, (slice, hslice, list, np.ndarray))

        def pd_indexing_func(obj: tp.SeriesFrame) -> tp.MaybeSeriesFrame:
            from vectorbtpro.base.indexes import drop_levels

            new_obj = obj.iloc[:, idxs]
            # A single parameter value makes its column level constant, so drop it
            if (
                not is_multiple
                and self.level_name is not None
                and checks.is_frame(new_obj)
                and isinstance(new_obj.columns, pd.MultiIndex)
            ):
                new_obj.columns = drop_levels(new_obj.columns, self.level_name)
            return new_obj

        return self.indexing_func(pd_indexing_func, **self.indexing_kwargs)

    def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
        idxs = self.get_idxs(key)
        if self.indexing_setter_func is None:
            # Fail with a descriptive message instead of "'NoneType' object is not callable"
            raise TypeError(
                f"'{type(self).__name__}' object does not support item assignment: "
                f"no indexing setter function was provided"
            )

        def pd_indexing_setter_func(obj: tp.SeriesFrame) -> None:
            obj.iloc[:, idxs] = value

        return self.indexing_setter_func(pd_indexing_setter_func, **self.indexing_kwargs)
For example, `pandas` doesn't let you query by list at a specific index/column level. Args: param_names (list of str): Names of the parameters. class_name (str): Name of the generated class. module_name (str): Name of the module to which the class should be bound. Usage: ```pycon >>> from vectorbtpro import * >>> from vectorbtpro.base.indexing import build_param_indexer, indexing_on_mapper >>> MyParamIndexer = build_param_indexer(['my_param']) >>> class C(MyParamIndexer): ... def __init__(self, df, param_mapper): ... self.df = df ... self._my_param_mapper = param_mapper ... super().__init__([param_mapper]) ... ... def indexing_func(self, pd_indexing_func): ... return type(self)( ... pd_indexing_func(self.df), ... indexing_on_mapper(self._my_param_mapper, self.df, pd_indexing_func) ... ) >>> df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) >>> param_mapper = pd.Series(['First', 'Second'], index=['a', 'b']) >>> c = C(df, param_mapper) >>> c.my_param_loc['First'].df 0 1 1 2 Name: a, dtype: int64 >>> c.my_param_loc['Second'].df 0 3 1 4 Name: b, dtype: int64 >>> c.my_param_loc[['First', 'First', 'Second', 'Second']].df a b 0 1 1 3 3 1 2 2 4 4 ``` """ class ParamIndexer(IndexingBase): """Class with parameter indexing.""" def __init__( self, param_mappers: tp.Sequence[tp.Series], level_names: tp.Optional[tp.LevelSequence] = None, **kwargs, ) -> None: checks.assert_len_equal(param_names, param_mappers) for i, param_name in enumerate(param_names): level_name = level_names[i] if level_names is not None else None _param_loc = ParamLoc(param_mappers[i], self.indexing_func, level_name=level_name, **kwargs) setattr(self, f"_{param_name}_loc", _param_loc) for i, param_name in enumerate(param_names): def param_loc(self, _param_name=param_name) -> ParamLoc: return getattr(self, f"_{_param_name}_loc") param_loc.__doc__ = f"""Access a group of columns by parameter `{param_name}` using `pd.Series.loc`. Forwards this operation to each Series/DataFrame and returns a new class instance. 
class IdxrBase(Base):
    """Abstract base for classes that resolve indices."""

    def get(self, *args, **kwargs) -> tp.Any:
        """Resolve and return indices. Must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def slice_indexer(
        cls,
        index: tp.Index,
        slice_: tp.Slice,
        closed_start: bool = True,
        closed_end: bool = False,
    ) -> slice:
        """Translate a slice of labels into a slice of positions in `index`.

        Each bound is located with `pd.Index.get_slice_bound`. `closed_start` and
        `closed_end` decide whether a bound that is present in the index is included.
        The step is passed through unchanged."""
        start, end = slice_.start, slice_.stop
        if start is not None:
            lower = index.get_slice_bound(start, side="left")
            upper = index.get_slice_bound(start, side="right")
            # Both sides differ only if the label exists; then inclusion depends on `closed_start`
            start = lower if (lower != upper and closed_start) else upper
        if end is not None:
            lower = index.get_slice_bound(end, side="left")
            upper = index.get_slice_bound(end, side="right")
            # Both sides differ only if the label exists; then inclusion depends on `closed_end`
            end = upper if (lower == upper or closed_end) else lower
        return slice(start, end, slice_.step)

    def check_idxs(self, idxs: tp.MaybeIndexArray, check_minus_one: bool = False) -> None:
        """Validate resolved indices.

        Accepts an integer, a slice with integer or None components, or a non-empty
        NumPy array of integer data type with one dimension, or with two dimensions
        and exactly two columns. If `check_minus_one` is True, any -1 is treated
        as an unmatched index and raises an error."""
        if isinstance(idxs, slice):
            if idxs.start is not None and not checks.is_int(idxs.start):
                raise TypeError("Start of a returned index slice must be an integer or None")
            if idxs.stop is not None and not checks.is_int(idxs.stop):
                raise TypeError("Stop of a returned index slice must be an integer or None")
            if idxs.step is not None and not checks.is_int(idxs.step):
                raise TypeError("Step of a returned index slice must be an integer or None")
            if check_minus_one:
                if idxs.start == -1:
                    raise ValueError("Range start index couldn't be matched")
                if idxs.stop == -1:
                    raise ValueError("Range end index couldn't be matched")
            return

        if checks.is_int(idxs):
            if check_minus_one and idxs == -1:
                raise ValueError("Index couldn't be matched")
            return

        if not checks.is_sequence(idxs) or np.isscalar(idxs):
            raise TypeError(
                f"Indices must be an integer, a slice, a NumPy array, or a tuple of two NumPy arrays, not {type(idxs)}"
            )

        if len(idxs) == 0:
            raise ValueError("No indices could be matched")
        if not isinstance(idxs, np.ndarray):
            raise ValueError(f"Indices must be a NumPy array, not {type(idxs)}")
        if not np.issubdtype(idxs.dtype, np.integer) or np.issubdtype(idxs.dtype, np.bool_):
            raise ValueError(f"Indices must be of integer data type, not {idxs.dtype}")
        if check_minus_one and -1 in idxs:
            raise ValueError("Some indices couldn't be matched")
        if idxs.ndim not in (1, 2):
            raise ValueError("Indices array must have either 1 or 2 dimensions")
        if idxs.ndim == 2 and idxs.shape[1] != 2:
            raise ValueError("Indices array provided as ranges must have exactly two columns")
class UniIdxr(IdxrBase):
    """Abstract class for resolving indices based on a single index.

    Supports combining indexers with operators. Each operator returns a `UniIdxrOp`
    that resolves its operands lazily and requires `index` at resolution time:

    * `~a`: positions not in `a`
    * `a & b`: intersection
    * `a | b`: union
    * `a - b`: difference
    * `a ^ b`: symmetric difference
    * `a << n`: positions of `a` shifted backward by integer `n`
    * `a >> n`: positions of `a` shifted forward by integer `n`

    Shifted positions that fall outside of the index are dropped."""

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        raise NotImplementedError

    def __invert__(self):
        def _op_func(x, index=None, freq=None):
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            idxs = np.setdiff1d(np.arange(len(index)), x)
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self)

    def __and__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            y = normalize_idxs(y, len(index))
            idxs = np.intersect1d(x, y)
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)

    def __or__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            y = normalize_idxs(y, len(index))
            idxs = np.union1d(x, y)
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)

    def __sub__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            y = normalize_idxs(y, len(index))
            idxs = np.setdiff1d(x, y)
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)

    def __xor__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            y = normalize_idxs(y, len(index))
            idxs = np.setxor1d(x, y)
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)

    def __lshift__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if not checks.is_int(y):
                raise TypeError("Second operand in __lshift__ must be an integer")
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            shifted = x - y
            # Keep only positions inside the index: a positive shift can fall below zero,
            # a negative shift can run past the end
            idxs = shifted[(shifted >= 0) & (shifted < len(index))]
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)

    def __rshift__(self, other):
        def _op_func(x, y, index=None, freq=None):
            if not checks.is_int(y):
                raise TypeError("Second operand in __rshift__ must be an integer")
            if index is None:
                raise ValueError("Index is required")
            x = normalize_idxs(x, len(index))
            shifted = x + y
            # Keep only positions inside the index: a positive shift can run past the end,
            # a negative shift can fall below zero
            idxs = shifted[(shifted >= 0) & (shifted < len(index))]
            self.check_idxs(idxs)
            return idxs

        return UniIdxrOp(_op_func, self, other)
@define
class LabelIdxr(UniIdxr, DefineMixin):
    """Indexer that resolves labels into positions."""

    value: tp.Union[None, tp.MaybeSequence[tp.Label], tp.Slice] = define.field()
    """One or more labels."""

    closed_start: bool = define.field(default=True)
    """Whether slice start should be inclusive."""

    closed_end: bool = define.field(default=True)
    """Whether slice end should be inclusive."""

    level: tp.MaybeLevelSequence = define.field(default=None)
    """One or more levels."""

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        value = self.value
        if value is None:
            # No value selects everything
            return slice(None, None, None)
        if index is None:
            raise ValueError("Index is required")
        if self.level is not None:
            from vectorbtpro.base.indexes import select_levels

            # Match labels against the selected level(s) only
            index = select_levels(index, self.level)

        if isinstance(value, (slice, hslice)):
            idxs = self.slice_indexer(
                index,
                value,
                closed_start=self.closed_start,
                closed_end=self.closed_end,
            )
        else:
            is_label_seq = checks.is_sequence(value) and not np.isscalar(value)
            # With a multi-index, a sequence is a collection of labels only if it holds tuples;
            # otherwise it is passed to `get_loc` as a single key
            if is_label_seq and (not isinstance(index, pd.MultiIndex) or isinstance(value[0], tuple)):
                idxs = index.get_indexer_for(value)
            else:
                idxs = index.get_loc(value)
                if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
                    # Convert a boolean mask into positions
                    idxs = np.flatnonzero(idxs)
        self.check_idxs(idxs, check_minus_one=True)
        return idxs
@define
class DatetimeIdxr(UniIdxr, DefineMixin):
    """Class for resolving indices provided as datetime-like objects."""

    value: tp.Union[None, tp.MaybeSequence[tp.DatetimeLike], tp.Slice] = define.field()
    """One or more datetime-like objects."""

    closed_start: bool = define.field(default=True)
    """Whether slice start should be inclusive."""

    closed_end: bool = define.field(default=False)
    """Whether slice end should be inclusive."""

    indexer_method: tp.Optional[str] = define.field(default="bfill")
    """Method for `pd.Index.get_indexer`.

    Allows two additional values: "before" and "after"."""

    below_to_zero: bool = define.field(default=False)
    """Whether to place 0 instead of -1 if `DatetimeIdxr.value` is below the first index."""

    above_to_len: bool = define.field(default=False)
    """Whether to place `len(index)` instead of -1 if `DatetimeIdxr.value` is above the last index."""

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        """Translate the stored datetime(s) into positions within `index`.

        Args:
            index: Datetime index to look the value(s) up in. Must be unique and
                monotonically increasing.
            freq: Not used by this indexer.

        Returns:
            An open slice if no value is stored, otherwise a slice, a position,
            or an array of positions.

        Raises:
            ValueError: If `index` is missing, not unique, or not sorted.
        """
        if self.value is None:
            return slice(None, None, None)
        if index is None:
            raise ValueError("Index is required")
        index = dt.prepare_dt_index(index)
        checks.assert_instance_of(index, pd.DatetimeIndex)
        if not index.is_unique:
            raise ValueError("Datetime index must be unique")
        if not index.is_monotonic_increasing:
            raise ValueError("Datetime index must be monotonically increasing")

        if isinstance(self.value, (slice, hslice)):
            start = dt.try_align_dt_to_index(self.value.start, index)
            stop = dt.try_align_dt_to_index(self.value.stop, index)
            new_value = slice(start, stop, self.value.step)
            idxs = self.slice_indexer(index, new_value, closed_start=self.closed_start, closed_end=self.closed_end)
        elif checks.is_sequence(self.value) and not np.isscalar(self.value):
            new_value = dt.try_align_to_dt_index(self.value, index)
            indexer_method = self.indexer_method
            if indexer_method is not None:
                indexer_method = indexer_method.lower()
                # "before"/"after" are not known to pandas. The scalar branch below resolves
                # exact matches to themselves and nudges everything else by 1ns before
                # filling, which (at nanosecond resolution) is the same as a plain
                # forward/backward fill - so translate instead of failing
                if indexer_method == "before":
                    indexer_method = "ffill"
                elif indexer_method == "after":
                    indexer_method = "bfill"
            idxs = index.get_indexer(new_value, method=indexer_method)
            if self.below_to_zero:
                idxs = np.where(new_value < index[0], 0, idxs)
            if self.above_to_len:
                idxs = np.where(new_value > index[-1], len(index), idxs)
        else:
            new_value = dt.try_align_dt_to_index(self.value, index)
            if new_value < index[0] and self.below_to_zero:
                idxs = 0
            elif new_value > index[-1] and self.above_to_len:
                idxs = len(index)
            else:
                if self.indexer_method is None or new_value in index:
                    # Exact lookup
                    idxs = index.get_loc(new_value)
                    if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
                        idxs = np.flatnonzero(idxs)
                else:
                    indexer_method = self.indexer_method.lower()
                    if indexer_method == "before":
                        new_value = new_value - pd.Timedelta(1, "ns")
                        indexer_method = "ffill"
                    elif indexer_method == "after":
                        new_value = new_value + pd.Timedelta(1, "ns")
                        indexer_method = "bfill"
                    idxs = index.get_indexer([new_value], method=indexer_method)[0]
        self.check_idxs(idxs, check_minus_one=True)
        return idxs
@define
class DTCIdxr(UniIdxr, DefineMixin):
    """Class for resolving indices provided as datetime-like components."""

    value: tp.Union[None, tp.MaybeSequence[tp.DTCLike], tp.Slice] = define.field()
    """One or more datetime-like components."""

    parse_kwargs: tp.KwargsLike = define.field(default=None)
    """Keyword arguments passed to `vectorbtpro.utils.datetime_.DTC.parse`."""

    closed_start: bool = define.field(default=True)
    """Whether slice start should be inclusive."""

    closed_end: bool = define.field(default=False)
    """Whether slice end should be inclusive."""

    jitted: tp.JittedOption = define.field(default=None)
    """Jitting option passed to `vectorbtpro.utils.datetime_nb.index_matches_dtc_nb`
    and `vectorbtpro.utils.datetime_nb.index_within_dtc_range_nb`."""

    @staticmethod
    def get_dtc_namedtuple(value: tp.Optional[tp.DTCLike] = None, **parse_kwargs) -> dt.DTCNT:
        """Convert a value to a `vectorbtpro.utils.datetime_.DTCNT` instance."""
        if value is None:
            return dt.DTC().to_namedtuple()
        if isinstance(value, dt.DTC):
            return value.to_namedtuple()
        if isinstance(value, dt.DTCNT):
            return value
        return dt.DTC.parse(value, **parse_kwargs).to_namedtuple()

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        """Build a mask over `index` from the stored component(s) and resolve it with `MaskIdxr`.

        Args:
            index: Datetime index to match against. Must be unique and
                monotonically increasing.
            freq: Forwarded to `MaskIdxr.get`.

        Returns:
            An open slice if no value is stored (or the stored slice is fully open),
            otherwise whatever `MaskIdxr.get` returns for the computed mask.

        Raises:
            ValueError: If `index` is missing, not unique, or not sorted, or if
                a slice with a step is provided.
        """
        if self.value is None:
            return slice(None, None, None)
        parse_kwargs = self.parse_kwargs
        if parse_kwargs is None:
            parse_kwargs = {}
        if index is None:
            raise ValueError("Index is required")
        index = dt.prepare_dt_index(index)
        # Validate the index type before converting it to nanoseconds
        # so that a wrong index fails with the intended assertion
        checks.assert_instance_of(index, pd.DatetimeIndex)
        if not index.is_unique:
            raise ValueError("Datetime index must be unique")
        if not index.is_monotonic_increasing:
            raise ValueError("Datetime index must be monotonically increasing")
        ns_index = dt.to_ns(index)

        if isinstance(self.value, (slice, hslice)):
            if self.value.step is not None:
                raise ValueError("Step must be None")
            if self.value.start is None and self.value.stop is None:
                return slice(None, None, None)
            start_dtc = self.get_dtc_namedtuple(self.value.start, **parse_kwargs)
            end_dtc = self.get_dtc_namedtuple(self.value.stop, **parse_kwargs)
            func = jit_reg.resolve_option(dt_nb.index_within_dtc_range_nb, self.jitted)
            mask = func(ns_index, start_dtc, end_dtc, closed_start=self.closed_start, closed_end=self.closed_end)
        elif checks.is_sequence(self.value) and not np.isscalar(self.value):
            func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
            masks = [func(ns_index, self.get_dtc_namedtuple(v, **parse_kwargs)) for v in self.value]
            if masks:
                # An index element is selected if it matches any of the components
                mask = functools.reduce(np.logical_or, masks)
            else:
                # No components given: select nothing instead of failing inside reduce()
                mask = np.zeros(len(index), dtype=np.bool_)
        else:
            dtc = self.get_dtc_namedtuple(self.value, **parse_kwargs)
            func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
            mask = func(ns_index, dtc)
        return MaskIdxr(mask).get(index=index, freq=freq)
@define
class PointIdxr(UniIdxr, DefineMixin):
    """Class for resolving index points."""

    every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Frequency either as an integer or timedelta.

    Gets translated into `on` array by creating a range. If integer, an index sequence from `start` to `end`
    (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence from `start`
    to `end` (inclusive) is created and 'labels' as `kind` is used.

    If `at_time` is not None and `every` and `on` are None, `every` defaults to one day."""

    normalize_every: bool = define.field(default=False)
    """Normalize start/end dates to midnight before generating date range."""

    at_time: tp.Optional[tp.TimeLike] = define.field(default=None)
    """Time of the day either as a (human-readable) string or `datetime.time`.

    Every datetime in `on` gets floored to the daily frequency, while `at_time` gets converted into
    a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_delta`.
    Index must be datetime-like."""

    start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
    """Start index/date.

    If (human-readable) string, gets converted into a datetime.

    If `every` is None, gets used to filter the final index array."""

    end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
    """End index/date.

    If (human-readable) string, gets converted into a datetime.

    If `every` is None, gets used to filter the final index array."""

    exact_start: bool = define.field(default=False)
    """Whether the first index should be exactly `start`.

    Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
    In such a case, `start` gets injected before the first index generated by `pd.date_range`."""

    on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
    """Index/label or a sequence of such.

    Gets converted into datetime format whenever possible."""

    add_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Offset to be added to each in `on`.

    Gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""

    kind: tp.Optional[str] = define.field(default=None)
    """Kind of data in `on`: indices or labels.

    If None, gets assigned to `indices` if `on` contains integer data, otherwise to `labels`.

    If `kind` is 'labels', `on` gets converted into indices using `pd.Index.get_indexer`.
    Prior to this, gets its timezone aligned to the timezone of the index. If `kind` is 'indices',
    `on` gets wrapped with NumPy."""

    indexer_method: str = define.field(default="bfill")
    """Method for `pd.Index.get_indexer`.

    Allows two additional values: "before" and "after"."""

    indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = define.field(default=None)
    """Tolerance for `pd.Index.get_indexer`.

    If `at_time` is set and `indexer_method` is neither exact nor nearest, `indexer_tolerance`
    becomes such that the next element must be within the current day."""

    skip_not_found: bool = define.field(default=True)
    """Whether to drop indices that are -1 (not found)."""

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        """Resolve index points by handing all fields of this instance to `get_index_points`.

        Raises a `ValueError` if no index is provided. `freq` is not used."""
        if index is None:
            raise ValueError("Index is required")
        point_kwargs = self.asdict()
        points = get_index_points(index, **point_kwargs)
        self.check_idxs(points, check_minus_one=True)
        return points
point_idxr_defaults = {a.name: a.default for a in PointIdxr.fields}


def get_index_points(
    index: tp.Index,
    every: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["every"],
    normalize_every: bool = point_idxr_defaults["normalize_every"],
    at_time: tp.Optional[tp.TimeLike] = point_idxr_defaults["at_time"],
    start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["start"],
    end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["end"],
    exact_start: bool = point_idxr_defaults["exact_start"],
    on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = point_idxr_defaults["on"],
    add_delta: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["add_delta"],
    kind: tp.Optional[str] = point_idxr_defaults["kind"],
    indexer_method: tp.Optional[str] = point_idxr_defaults["indexer_method"],
    indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = point_idxr_defaults[
        "indexer_tolerance"
    ],
    skip_not_found: bool = point_idxr_defaults["skip_not_found"],
) -> tp.Array1d:
    """Translate indices or labels into index points.

    See `PointIdxr` for arguments.

    Usage:
        * Provide nothing to generate at the beginning:

        ```pycon
        >>> from vectorbtpro import *

        >>> index = pd.date_range("2020-01", "2020-02", freq="1d")

        >>> vbt.get_index_points(index)
        array([0])
        ```

        * Provide `every` as an integer frequency to generate index points using NumPy:

        ```pycon
        >>> # Generate a point every five rows
        >>> vbt.get_index_points(index, every=5)
        array([ 0,  5, 10, 15, 20, 25, 30])

        >>> # Generate a point every five rows starting at 6th row
        >>> vbt.get_index_points(index, every=5, start=5)
        array([ 5, 10, 15, 20, 25, 30])

        >>> # Generate a point every five rows from 6th to 16th row
        >>> vbt.get_index_points(index, every=5, start=5, end=15)
        array([ 5, 10])
        ```

        * Provide `every` as a time delta frequency to generate index points using Pandas:

        ```pycon
        >>> # Generate a point every week
        >>> vbt.get_index_points(index, every="W")
        array([ 4, 11, 18, 25])

        >>> # Generate a point every second day of the week
        >>> vbt.get_index_points(index, every="W", add_delta="2d")
        array([ 6, 13, 20, 27])

        >>> # Generate a point every week, starting at 11th row
        >>> vbt.get_index_points(index, every="W", start=10)
        array([11, 18, 25])

        >>> # Generate a point every week, starting exactly at 11th row
        >>> vbt.get_index_points(index, every="W", start=10, exact_start=True)
        array([10, 11, 18, 25])

        >>> # Generate a point every week, starting at 2020-01-10
        >>> vbt.get_index_points(index, every="W", start="2020-01-10")
        array([11, 18, 25])
        ```

        * Instead of using `every`, provide indices explicitly:

        ```pycon
        >>> # Generate one point
        >>> vbt.get_index_points(index, on="2020-01-07")
        array([6])

        >>> # Generate multiple points
        >>> vbt.get_index_points(index, on=["2020-01-07", "2020-01-14"])
        array([ 6, 13])
        ```
    """
    index = dt.prepare_dt_index(index)
    if on is not None and isinstance(on, str):
        on = dt.try_align_dt_to_index(on, index)
    if start is not None and isinstance(start, str):
        start = dt.try_align_dt_to_index(start, index)
    if end is not None and isinstance(end, str):
        end = dt.try_align_dt_to_index(end, index)
    if every is not None and not checks.is_int(every):
        every = dt.to_freq(every)

    # Track whether start/end were consumed while building `on`;
    # if not, they are applied as filters on the final array
    start_used = False
    end_used = False
    if at_time is not None and every is None and on is None:
        every = pd.Timedelta(days=1)
    if every is not None:
        start_used = True
        end_used = True
        if checks.is_int(every):
            if start is None:
                start = 0
            if end is None:
                end = len(index)
            on = np.arange(start, end, every)
            kind = "indices"
        else:
            if start is None:
                start = 0
            if checks.is_int(start):
                start_date = index[start]
            else:
                start_date = start
            if end is None:
                end = len(index) - 1
            if checks.is_int(end):
                end_date = index[end]
            else:
                end_date = end
            on = dt.date_range(
                start_date,
                end_date,
                freq=every,
                tz=index.tz,
                normalize=normalize_every,
                inclusive="both",
            )
            # Also covers an empty range, where there is no first element to compare against
            if exact_start and (len(on) == 0 or on[0] > start_date):
                on = on.insert(0, start_date)
            kind = "labels"

    if kind is None:
        if on is None:
            if start is not None:
                if checks.is_int(start):
                    kind = "indices"
                else:
                    kind = "labels"
            else:
                kind = "indices"
        else:
            on = dt.prepare_dt_index(on)
            if pd.api.types.is_integer_dtype(on):
                kind = "indices"
            else:
                kind = "labels"
    checks.assert_in(kind, ("indices", "labels"))
    if on is None:
        if start is not None:
            on = start
            start_used = True
        else:
            if kind.lower() in ("labels",):
                on = index
            else:
                on = np.arange(len(index))
    on = dt.prepare_dt_index(on)

    add_time_delta = None
    if at_time is not None:
        checks.assert_instance_of(on, pd.DatetimeIndex)
        on = on.floor("D")
        add_time_delta = dt.time_to_timedelta(at_time)
        # An exact lookup (no method) has nothing to derive a tolerance from
        if indexer_tolerance is None and indexer_method is not None:
            indexer_method = indexer_method.lower()
            # Keep the match within the same day
            if indexer_method in ("pad", "ffill"):
                indexer_tolerance = add_time_delta
            elif indexer_method in ("backfill", "bfill"):
                indexer_tolerance = pd.Timedelta(days=1) - pd.Timedelta(1, "ns") - add_time_delta
    # Apply the user delta and the time-of-day delta separately: `add_delta` may still be
    # a string or an offset here, neither of which can be summed with a timedelta up front
    if add_delta is not None:
        on = on + dt.to_freq(add_delta)
    if add_time_delta is not None:
        on = on + add_time_delta

    if kind.lower() == "labels":
        on = dt.try_align_to_dt_index(on, index)
        if indexer_method is not None:
            indexer_method = indexer_method.lower()
        if indexer_method == "before":
            on = on - pd.Timedelta(1, "ns")
            indexer_method = "ffill"
        elif indexer_method == "after":
            on = on + pd.Timedelta(1, "ns")
            indexer_method = "bfill"
        index_points = index.get_indexer(on, method=indexer_method, tolerance=indexer_tolerance)
    else:
        index_points = np.asarray(on)

    if start is not None and not start_used:
        if not checks.is_int(start):
            start = index.get_indexer([start], method="bfill").item(0)
        index_points = index_points[index_points >= start]
    if end is not None and not end_used:
        if not checks.is_int(end):
            # Date-like end is inclusive, integer end is exclusive
            end = index.get_indexer([end], method="ffill").item(0)
            index_points = index_points[index_points <= end]
        else:
            index_points = index_points[index_points < end]

    if skip_not_found:
        index_points = index_points[index_points != -1]

    return index_points
@define
class RangeIdxr(UniIdxr, DefineMixin):
    """Class for resolving index ranges."""

    every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Frequency either as an integer or timedelta.

    Gets translated into `start` and `end` arrays by creating a range. If integer, an index sequence from `start`
    to `end` (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence
    from `start` to `end` (inclusive) is created and 'bounds' as `kind` is used.

    If `start_time` or `end_time` is not None and `every`, `start`, and `end` are None,
    `every` defaults to one day."""

    normalize_every: bool = define.field(default=False)
    """Normalize start/end dates to midnight before generating date range."""

    split_every: bool = define.field(default=True)
    """Whether to split the sequence generated using `every` into `start` and `end` arrays.

    After creation, and if `split_every` is True, an index range is created from each pair of elements in
    the generated sequence. Otherwise, the entire sequence is assigned to `start` and `end`, and only time
    and delta instructions can be used to further differentiate between them.

    Forced to False if `every`, `start_time`, and `end_time` are not None and `fixed_start` is False."""

    start_time: tp.Optional[tp.TimeLike] = define.field(default=None)
    """Start time of the day either as a (human-readable) string or `datetime.time`.

    Every datetime in `start` gets floored to the daily frequency, while `start_time` gets converted into
    a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_start_delta`.
    Index must be datetime-like."""

    end_time: tp.Optional[tp.TimeLike] = define.field(default=None)
    """End time of the day either as a (human-readable) string or `datetime.time`.

    Every datetime in `end` gets floored to the daily frequency, while `end_time` gets converted into
    a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_end_delta`.
    Index must be datetime-like."""

    lookback_period: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Lookback period either as an integer or offset.

    If `lookback_period` is set, `start` becomes `end-lookback_period`. If `every` is not None,
    the sequence is generated from `start+lookback_period` to `end` and then assigned to `end`.

    If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.
    If integer, gets multiplied by the frequency of the index if the index is not integer."""

    start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
    """Start index/label or a sequence of such.

    Gets converted into datetime format whenever possible.

    Gets broadcasted together with `end`."""

    end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
    """End index/label or a sequence of such.

    Gets converted into datetime format whenever possible.

    Gets broadcasted together with `start`."""

    exact_start: bool = define.field(default=False)
    """Whether the first index in the `start` array should be exactly `start`.

    Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
    In such a case, `start` gets injected before the first index generated by `pd.date_range`.

    Cannot be used together with `lookback_period`."""

    fixed_start: bool = define.field(default=False)
    """Whether all indices in the `start` array should be exactly `start`.

    Works only together with `every`.

    Cannot be used together with `lookback_period`."""

    closed_start: bool = define.field(default=True)
    """Whether `start` should be inclusive."""

    closed_end: bool = define.field(default=False)
    """Whether `end` should be inclusive."""

    add_start_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Offset to be added to each in `start`.

    If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""

    add_end_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
    """Offset to be added to each in `end`.

    If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""

    kind: tp.Optional[str] = define.field(default=None)
    """Kind of data in `start` and `end`: indices, labels or bounds.

    If None, gets assigned to `indices` if `start` and `end` contain integer data, to `bounds`
    if `start`, `end`, and index are datetime-like, otherwise to `labels`.

    If `kind` is 'labels', `start` and `end` get converted into indices using `pd.Index.get_indexer`.
    Prior to this, get their timezone aligned to the timezone of the index. If `kind` is 'indices',
    `start` and `end` get wrapped with NumPy. If `kind` is 'bounds',
    `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges` is used."""

    skip_not_found: bool = define.field(default=True)
    """Whether to drop indices that are -1 (not found)."""

    jitted: tp.JittedOption = define.field(default=None)
    """Jitting option passed to `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges`."""

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        """Resolve index ranges by handing all fields of this instance to `get_index_ranges`.

        Args:
            index: Index to resolve the ranges against.
            freq: Forwarded to `get_index_ranges` as `index_freq`.

        Returns:
            The start and end arrays stacked as two columns.

        Raises:
            ValueError: If `index` is None.
        """
        if index is None:
            raise ValueError("Index is required")
        from vectorbtpro.base.merging import column_stack_arrays

        start_idxs, end_idxs = get_index_ranges(index, index_freq=freq, **self.asdict())
        idxs = column_stack_arrays((start_idxs, end_idxs))
        self.check_idxs(idxs, check_minus_one=True)
        return idxs
range_idxr_defaults = {a.name: a.default for a in RangeIdxr.fields}


def get_index_ranges(
    index: tp.Index,
    index_freq: tp.Optional[tp.FrequencyLike] = None,
    every: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["every"],
    normalize_every: bool = range_idxr_defaults["normalize_every"],
    split_every: bool = range_idxr_defaults["split_every"],
    start_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["start_time"],
    end_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["end_time"],
    lookback_period: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["lookback_period"],
    start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["start"],
    end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["end"],
    exact_start: bool = range_idxr_defaults["exact_start"],
    fixed_start: bool = range_idxr_defaults["fixed_start"],
    closed_start: bool = range_idxr_defaults["closed_start"],
    closed_end: bool = range_idxr_defaults["closed_end"],
    add_start_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_start_delta"],
    add_end_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_end_delta"],
    kind: tp.Optional[str] = range_idxr_defaults["kind"],
    skip_not_found: bool = range_idxr_defaults["skip_not_found"],
    jitted: tp.JittedOption = range_idxr_defaults["jitted"],
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Translate indices, labels, or bounds into index ranges.

    See `RangeIdxr` for arguments.

    Usage:
        * Provide nothing to generate one largest index range:

        ```pycon
        >>> from vectorbtpro import *

        >>> index = pd.date_range("2020-01", "2020-02", freq="1d")

        >>> np.column_stack(vbt.get_index_ranges(index))
        array([[ 0, 32]])
        ```

        * Provide `every` as an integer frequency to generate index ranges using NumPy:

        ```pycon
        >>> # Generate a range every five rows
        >>> np.column_stack(vbt.get_index_ranges(index, every=5))
        array([[ 0,  5],
               [ 5, 10],
               [10, 15],
               [15, 20],
               [20, 25],
               [25, 30]])

        >>> # Generate a range every five rows, starting at 6th row
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every=5,
        ...     start=5
        ... ))
        array([[ 5, 10],
               [10, 15],
               [15, 20],
               [20, 25],
               [25, 30]])

        >>> # Generate a range every five rows from 6th to 16th row
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every=5,
        ...     start=5,
        ...     end=15
        ... ))
        array([[ 5, 10],
               [10, 15]])
        ```

        * Provide `every` as a time delta frequency to generate index ranges using Pandas:

        ```pycon
        >>> # Generate a range every week
        >>> np.column_stack(vbt.get_index_ranges(index, every="W"))
        array([[ 4, 11],
               [11, 18],
               [18, 25]])

        >>> # Generate a range every second day of the week
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     add_start_delta="2d"
        ... ))
        array([[ 6, 11],
               [13, 18],
               [20, 25]])

        >>> # Generate a range every week, starting at 11th row
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     start=10
        ... ))
        array([[11, 18],
               [18, 25]])

        >>> # Generate a range every week, starting exactly at 11th row
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     start=10,
        ...     exact_start=True
        ... ))
        array([[10, 11],
               [11, 18],
               [18, 25]])

        >>> # Generate a range every week, starting at 2020-01-10
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     start="2020-01-10"
        ... ))
        array([[11, 18],
               [18, 25]])

        >>> # Generate a range every week, each starting at 2020-01-10
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     start="2020-01-10",
        ...     fixed_start=True
        ... ))
        array([[11, 18],
               [11, 25]])

        >>> # Generate an expanding range that increments by week
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     start=0,
        ...     exact_start=True,
        ...     fixed_start=True
        ... ))
        array([[ 0,  4],
               [ 0, 11],
               [ 0, 18],
               [ 0, 25]])
        ```

        * Use a look-back period (instead of an end index):

        ```pycon
        >>> # Generate a range every week, looking 5 days back
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     lookback_period=5
        ... ))
        array([[ 6, 11],
               [13, 18],
               [20, 25]])

        >>> # Generate a range every week, looking 2 weeks back
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     every="W",
        ...     lookback_period="2W"
        ... ))
        array([[ 0, 11],
               [ 4, 18],
               [11, 25]])
        ```

        * Instead of using `every`, provide start and end indices explicitly:

        ```pycon
        >>> # Generate one range
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     start="2020-01-01",
        ...     end="2020-01-07"
        ... ))
        array([[0, 6]])

        >>> # Generate ranges between multiple dates
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     start=["2020-01-01", "2020-01-07"],
        ...     end=["2020-01-07", "2020-01-14"]
        ... ))
        array([[ 0,  6],
               [ 6, 13]])

        >>> # Generate ranges with a fixed start
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     start="2020-01-01",
        ...     end=["2020-01-07", "2020-01-14"]
        ... ))
        array([[ 0,  6],
               [ 0, 13]])
        ```

        * Use `closed_start` and `closed_end` to exclude any of the bounds:

        ```pycon
        >>> # Generate ranges between multiple dates
        >>> # by excluding the start date and including the end date
        >>> np.column_stack(vbt.get_index_ranges(
        ...     index,
        ...     start=["2020-01-01", "2020-01-07"],
        ...     end=["2020-01-07", "2020-01-14"],
        ...     closed_start=False,
        ...     closed_end=True
        ... ))
        array([[ 1,  7],
               [ 7, 14]])
        ```
    """
    from vectorbtpro.base.indexes import repeat_index
    from vectorbtpro.base.resampling.base import Resampler

    index = dt.prepare_dt_index(index)
    if isinstance(index, pd.DatetimeIndex):
        # Align the bounds to the index and then work with timezone-naive values throughout
        if start is not None:
            start = dt.try_align_to_dt_index(start, index)
            if isinstance(start, pd.DatetimeIndex):
                start = start.tz_localize(None)
        if end is not None:
            end = dt.try_align_to_dt_index(end, index)
            if isinstance(end, pd.DatetimeIndex):
                end = end.tz_localize(None)
        naive_index = index.tz_localize(None)
    else:
        if start is not None:
            if not isinstance(start, pd.Index):
                try:
                    start = pd.Index(start)
                except Exception:
                    # Not a collection: wrap the single value
                    start = pd.Index([start])
        if end is not None:
            if not isinstance(end, pd.Index):
                try:
                    end = pd.Index(end)
                except Exception:
                    end = pd.Index([end])
        naive_index = index
    if every is not None and not checks.is_int(every):
        every = dt.to_freq(every)
    if lookback_period is not None and not checks.is_int(lookback_period):
        lookback_period = dt.to_freq(lookback_period)
    if fixed_start and lookback_period is not None:
        raise ValueError("Cannot use fixed_start and lookback_period together")
    if exact_start and lookback_period is not None:
        raise ValueError("Cannot use exact_start and lookback_period together")

    if start_time is not None or end_time is not None:
        if every is None and start is None and end is None:
            every = pd.Timedelta(days=1)

    if every is not None:
        if not fixed_start:
            # A single time instruction gets midnight as its counterpart
            if start_time is None and end_time is not None:
                start_time = time(0, 0, 0, 0)
                closed_start = True
            if start_time is not None and end_time is None:
                end_time = time(0, 0, 0, 0)
                closed_end = False
        if start_time is not None and end_time is not None and not fixed_start:
            split_every = False

        if checks.is_int(every):
            if start is None:
                start = 0
            else:
                start = start[0]
            if end is None:
                end = len(naive_index)
            else:
                end = end[-1]
                if closed_end:
                    end -= 1
            if lookback_period is None:
                new_index = np.arange(start, end + 1, every)
                if not split_every:
                    start = end = new_index
                else:
                    if fixed_start:
                        start = np.full(len(new_index) - 1, new_index[0])
                    else:
                        start = new_index[:-1]
                    end = new_index[1:]
            else:
                end = np.arange(start + lookback_period, end + 1, every)
                start = end - lookback_period
            kind = "indices"
            lookback_period = None
        else:
            if start is None:
                start = 0
            else:
                start = start[0]
            if checks.is_int(start):
                start_date = naive_index[start]
            else:
                start_date = start
            if end is None:
                end = len(naive_index) - 1
            else:
                end = end[-1]
            if checks.is_int(end):
                end_date = naive_index[end]
            else:
                end_date = end
            if lookback_period is None:
                new_index = dt.date_range(
                    start_date,
                    end_date,
                    freq=every,
                    normalize=normalize_every,
                    inclusive="both",
                )
                if exact_start and new_index[0] > start_date:
                    new_index = new_index.insert(0, start_date)
                if not split_every:
                    start = end = new_index
                else:
                    if fixed_start:
                        start = repeat_index(new_index[[0]], len(new_index) - 1)
                    else:
                        start = new_index[:-1]
                    end = new_index[1:]
            else:
                if checks.is_int(lookback_period):
                    lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
                if isinstance(lookback_period, BaseOffset):
                    # Generate all ends first and drop those whose start falls before the start date
                    end = dt.date_range(
                        start_date,
                        end_date,
                        freq=every,
                        normalize=normalize_every,
                        inclusive="both",
                    )
                    start = end - lookback_period
                    start_mask = start >= start_date
                    start = start[start_mask]
                    end = end[start_mask]
                else:
                    end = dt.date_range(
                        start_date + lookback_period,
                        end_date,
                        freq=every,
                        normalize=normalize_every,
                        inclusive="both",
                    )
                    start = end - lookback_period
            kind = "bounds"
            lookback_period = None

    if kind is None:
        if start is None and end is None:
            kind = "indices"
        else:
            if start is not None:
                ref_index = start
            if end is not None:
                ref_index = end
            if pd.api.types.is_integer_dtype(ref_index):
                kind = "indices"
            elif isinstance(ref_index, pd.DatetimeIndex) and isinstance(naive_index, pd.DatetimeIndex):
                kind = "bounds"
            else:
                kind = "labels"
    checks.assert_in(kind, ("indices", "labels", "bounds"))
    if end is None:
        if kind.lower() in ("labels", "bounds"):
            end = pd.Index([naive_index[-1]])
        else:
            end = pd.Index([len(naive_index)])
    if start is not None and lookback_period is not None:
        raise ValueError("Cannot use start and lookback_period together")
    if start is None:
        if lookback_period is None:
            if kind.lower() in ("labels", "bounds"):
                start = pd.Index([naive_index[0]])
            else:
                start = pd.Index([0])
        else:
            if checks.is_int(lookback_period) and not pd.api.types.is_integer_dtype(end):
                lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
            start = end - lookback_period
    # Broadcast a single start/end against multiple ends/starts
    if len(start) == 1 and len(end) > 1:
        start = repeat_index(start, len(end))
    elif len(start) > 1 and len(end) == 1:
        end = repeat_index(end, len(start))
    checks.assert_len_equal(start, end)

    if start_time is not None:
        checks.assert_instance_of(start, pd.DatetimeIndex)
        start = start.floor("D")
        add_start_time_delta = dt.time_to_timedelta(start_time)
        if add_start_delta is None:
            add_start_delta = add_start_time_delta
        else:
            # Convert first: the user delta may still be a string at this point
            add_start_delta = dt.to_freq(add_start_delta) + add_start_time_delta
    else:
        add_start_time_delta = None
    if end_time is not None:
        checks.assert_instance_of(end, pd.DatetimeIndex)
        end = end.floor("D")
        add_end_time_delta = dt.time_to_timedelta(end_time)
        if add_start_time_delta is not None:
            # End before the (shifted) start is pushed to the next day
            if add_end_time_delta < add_start_delta:
                add_end_time_delta += pd.Timedelta(days=1)
        if add_end_delta is None:
            add_end_delta = add_end_time_delta
        else:
            add_end_delta = dt.to_freq(add_end_delta) + add_end_time_delta

    if add_start_delta is not None:
        start += dt.to_freq(add_start_delta)
    if add_end_delta is not None:
        end += dt.to_freq(add_end_delta)

    if kind.lower() == "bounds":
        range_starts, range_ends = Resampler.map_bounds_to_source_ranges(
            source_index=naive_index.values,
            target_lbound_index=start.values,
            target_rbound_index=end.values,
            closed_lbound=closed_start,
            closed_rbound=closed_end,
            skip_not_found=skip_not_found,
            jitted=jitted,
        )
    else:
        if kind.lower() == "labels":
            range_starts = np.empty(len(start), dtype=int_)
            range_ends = np.empty(len(end), dtype=int_)
            # Positions addressable by label
            range_index = pd.Series(np.arange(len(naive_index)), index=naive_index)
            for i in range(len(range_starts)):
                selected_range = range_index[start[i] : end[i]]
                if len(selected_range) > 0 and not closed_start and selected_range.index[0] == start[i]:
                    selected_range = selected_range.iloc[1:]
                if len(selected_range) > 0 and not closed_end and selected_range.index[-1] == end[i]:
                    selected_range = selected_range.iloc[:-1]
                if len(selected_range) > 0:
                    range_starts[i] = selected_range.iloc[0]
                    range_ends[i] = selected_range.iloc[-1]
                else:
                    range_starts[i] = -1
                    range_ends[i] = -1
        else:
            if not closed_start:
                start = start + 1
            if closed_end:
                end = end + 1
            range_starts = np.asarray(start)
            range_ends = np.asarray(end)
        if skip_not_found:
            valid_mask = (range_starts != -1) & (range_ends != -1)
            range_starts = range_starts[valid_mask]
            range_ends = range_ends[valid_mask]

    if np.any(range_starts >= range_ends):
        raise ValueError("Some start indices are equal to or higher than end indices")

    return range_starts, range_ends
@define
class AutoIdxr(UniIdxr, DefineMixin):
    """Class for resolving indices, datetime-like objects, frequency-like objects, and labels for one axis."""

    value: tp.Union[
        None,
        tp.PosSel,
        tp.LabelSel,
        tp.MaybeSequence[tp.MaybeSequence[int]],
        tp.MaybeSequence[tp.Label],
        tp.MaybeSequence[tp.DatetimeLike],
        tp.MaybeSequence[tp.DTCLike],
        tp.FrequencyLike,
        tp.Slice,
    ] = define.field()
    """One or more integer indices, datetime-like objects, frequency-like objects, or labels.

    Can also be an instance of `vectorbtpro.utils.selection.PosSel` holding position(s)
    and `vectorbtpro.utils.selection.LabelSel` holding label(s)."""

    closed_start: bool = define.optional_field()
    """Whether slice start should be inclusive."""

    closed_end: bool = define.optional_field()
    """Whether slice end should be inclusive."""

    indexer_method: tp.Optional[str] = define.optional_field()
    """Method for `pd.Index.get_indexer`."""

    below_to_zero: bool = define.optional_field()
    """Whether to place 0 instead of -1 if `AutoIdxr.value` is below the first index."""

    above_to_len: bool = define.optional_field()
    """Whether to place `len(index)` instead of -1 if `AutoIdxr.value` is above the last index."""

    level: tp.MaybeLevelSequence = define.field(default=None)
    """One or more levels.

    If `level` is not None and `kind` is None, `kind` becomes "labels"."""

    kind: tp.Optional[str] = define.field(default=None)
    """Kind of value.

    Allowed are

    * "position(s)" for `PosIdxr`
    * "mask" for `MaskIdxr`
    * "label(s)" for `LabelIdxr`
    * "datetime" for `DatetimeIdxr`
    * "dtc": for `DTCIdxr`
    * "frequency" for `PointIdxr`

    If None, will (try to) determine automatically based on the type of indices."""

    idxr_kwargs: tp.KwargsLike = define.field(default=None)
    """Keyword arguments passed to the selected indexer."""

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the indexer; keyword arguments that are not fields end up in `idxr_kwargs`."""
        idxr_kwargs = kwargs.pop("idxr_kwargs", None)
        if idxr_kwargs is None:
            idxr_kwargs = {}
        else:
            idxr_kwargs = dict(idxr_kwargs)
        builtin_keys = {a.name for a in self.fields}
        for k in list(kwargs.keys()):
            if k not in builtin_keys:
                idxr_kwargs[k] = kwargs.pop(k)
        DefineMixin.__init__(self, *args, idxr_kwargs=idxr_kwargs, **kwargs)

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
    ) -> tp.MaybeIndexArray:
        """Determine the kind of the stored value, build the matching indexer, and delegate to it.

        Args:
            index: Index to resolve against. Required whenever `level` is set or
                the kind cannot be determined from the value alone.
            freq: Forwarded to the selected indexer.

        Returns:
            An open slice if no value is stored, otherwise the result of the selected indexer.

        Raises:
            ValueError: If an index is needed but missing, if a string is neither
                a frequency nor a datetime, or if `kind` is not recognized.
        """
        if self.value is None:
            return slice(None, None, None)
        value = self.value
        kind = self.kind
        if self.level is not None:
            from vectorbtpro.base.indexes import select_levels

            if index is None:
                raise ValueError("Index is required")
            index = select_levels(index, self.level)
            if kind is None:
                kind = "labels"

        # Work on a copy of the user-supplied keyword arguments (if any)
        if self.idxr_kwargs is not None:
            idxr_kwargs = dict(self.idxr_kwargs)
        else:
            idxr_kwargs = {}

        def _dtc_check_func(dtc):
            # Treat as datetime components only if it isn't a full datetime
            # and no datetime-only option has been set
            return (
                not dtc.has_full_datetime()
                and self.indexer_method in (MISSING, None)
                and self.below_to_zero is MISSING
                and self.above_to_len is MISSING
            )

        if kind is None:
            if isinstance(value, PosSel):
                kind = "positions"
                value = value.value
            elif isinstance(value, LabelSel):
                kind = "labels"
                value = value.value
            elif isinstance(value, (slice, hslice)):
                if checks.is_int(value.start) or checks.is_int(value.stop):
                    kind = "positions"
                elif value.start is None and value.stop is None:
                    kind = "positions"
                else:
                    if index is None:
                        raise ValueError("Index is required")
                    if isinstance(index, pd.DatetimeIndex):
                        if dt.DTC.is_parsable(value.start, check_func=_dtc_check_func) or dt.DTC.is_parsable(
                            value.stop, check_func=_dtc_check_func
                        ):
                            kind = "dtc"
                        else:
                            kind = "datetime"
                    else:
                        kind = "labels"
            elif (checks.is_sequence(value) and not np.isscalar(value)) and (
                index is None
                or (
                    not isinstance(index, pd.MultiIndex)
                    or (isinstance(index, pd.MultiIndex) and isinstance(value[0], tuple))
                )
            ):
                if checks.is_bool(value[0]):
                    kind = "mask"
                elif checks.is_int(value[0]):
                    kind = "positions"
                elif (
                    (index is None or not isinstance(index, pd.MultiIndex) or not isinstance(value[0], tuple))
                    and checks.is_sequence(value[0])
                    and len(value[0]) == 2
                    and checks.is_int(value[0][0])
                    and checks.is_int(value[0][1])
                ):
                    # Sequence of integer pairs
                    kind = "positions"
                else:
                    if index is None:
                        raise ValueError("Index is required")
                    elif isinstance(index, pd.DatetimeIndex):
                        if dt.DTC.is_parsable(value[0], check_func=_dtc_check_func):
                            kind = "dtc"
                        else:
                            kind = "datetime"
                    else:
                        kind = "labels"
            else:
                if checks.is_bool(value):
                    kind = "mask"
                elif checks.is_int(value):
                    kind = "positions"
                else:
                    if index is None:
                        raise ValueError("Index is required")
                    if isinstance(index, pd.DatetimeIndex):
                        if dt.DTC.is_parsable(value, check_func=_dtc_check_func):
                            kind = "dtc"
                        elif isinstance(value, str):
                            try:
                                if not value.isupper() and not value.islower():
                                    raise Exception  # "2020" shouldn't be a frequency
                                _ = dt.to_freq(value)
                                kind = "frequency"
                            except Exception:
                                try:
                                    _ = dt.to_timestamp(value)
                                    kind = "datetime"
                                except Exception:
                                    raise ValueError(f"'{value}' is neither a frequency nor a datetime")
                        elif checks.is_frequency(value):
                            kind = "frequency"
                        else:
                            kind = "datetime"
                    else:
                        kind = "labels"

        def _expand_target_kwargs(target_cls, **target_kwargs):
            # Forward optional fields that were explicitly set and that the target indexer knows
            source_arg_names = {a.name for a in self.fields if a.default is MISSING}
            target_arg_names = {a.name for a in target_cls.fields}
            for arg_name in source_arg_names:
                if arg_name in target_arg_names:
                    arg_value = getattr(self, arg_name)
                    if arg_value is not MISSING:
                        target_kwargs[arg_name] = arg_value
            return target_kwargs

        if kind.lower() in ("position", "positions"):
            idx = PosIdxr(value, **_expand_target_kwargs(PosIdxr, **idxr_kwargs))
        elif kind.lower() == "mask":
            idx = MaskIdxr(value, **_expand_target_kwargs(MaskIdxr, **idxr_kwargs))
        elif kind.lower() in ("label", "labels"):
            idx = LabelIdxr(value, **_expand_target_kwargs(LabelIdxr, **idxr_kwargs))
        elif kind.lower() == "datetime":
            idx = DatetimeIdxr(value, **_expand_target_kwargs(DatetimeIdxr, **idxr_kwargs))
        elif kind.lower() == "dtc":
            idx = DTCIdxr(value, **_expand_target_kwargs(DTCIdxr, **idxr_kwargs))
        elif kind.lower() == "frequency":
            idx = PointIdxr(every=value, **_expand_target_kwargs(PointIdxr, **idxr_kwargs))
        else:
            raise ValueError(f"Invalid kind: '{kind}'")
        return idx.get(index=index, freq=freq)
@define
class Idxr(IdxrBase, DefineMixin):
    """Class for resolving indices."""

    idxrs: tp.Tuple[object, ...] = define.field()
    """A tuple of one or more indexers.

    If one indexer is provided, can be an instance of `RowIdxr` or `ColIdxr`,
    a custom template, or a value to be wrapped with `RowIdxr`.

    If two indexers are provided, can be an instance of `RowIdxr` and `ColIdxr` respectively,
    or a value to be wrapped with `RowIdxr` and `ColIdxr` respectively."""

    idxr_kwargs: tp.KwargsLike = define.field()
    """Keyword arguments passed to `RowIdxr` and `ColIdxr`."""

    def __init__(self, *idxrs: object, **idxr_kwargs) -> None:
        """Store positional arguments as indexers and keyword arguments as `Idxr.idxr_kwargs`."""
        DefineMixin.__init__(self, idxrs=idxrs, idxr_kwargs=hdict(idxr_kwargs))

    def get(
        self,
        index: tp.Optional[tp.Index] = None,
        columns: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
        template_context: tp.KwargsLike = None,
    ) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]:
        """Resolve the stored indexers into a tuple of row indices and column indices.

        A single custom template is substituted first (with `index`, `columns`, and `freq`
        merged with `template_context`) and the result is resolved again as a new instance
        of this class; a tuple result is unpacked into separate indexers.

        Raises `ValueError` if there are no or more than two indexers, and `TypeError`
        if an indexer of an incompatible class is used for rows or columns."""
        if len(self.idxrs) == 0:
            raise ValueError("Must provide at least one indexer")
        elif len(self.idxrs) == 1:
            idxr = self.idxrs[0]
            if isinstance(idxr, CustomTemplate):
                _template_context = merge_dicts(dict(index=index, columns=columns, freq=freq), template_context)
                idxr = idxr.substitute(_template_context, eval_id="idxr")
                # Fixed: carry idxr_kwargs over to the re-wrapped indexer,
                # otherwise they are lost whenever a template is used
                if isinstance(idxr, tuple):
                    new_idxr = type(self)(*idxr, **self.idxr_kwargs)
                else:
                    new_idxr = type(self)(idxr, **self.idxr_kwargs)
                return new_idxr.get(
                    index=index,
                    columns=columns,
                    freq=freq,
                    template_context=template_context,
                )
            if isinstance(idxr, ColIdxr):
                row_idxr = None
                col_idxr = idxr
            else:
                row_idxr = idxr
                col_idxr = None
        elif len(self.idxrs) == 2:
            row_idxr = self.idxrs[0]
            col_idxr = self.idxrs[1]
        else:
            raise ValueError("Must provide at most two indexers")
        if not isinstance(row_idxr, RowIdxr):
            if isinstance(row_idxr, (ColIdxr, Idxr)):
                raise TypeError(f"Indexer {type(row_idxr)} not supported as a row indexer")
            row_idxr = RowIdxr(row_idxr, **self.idxr_kwargs)
        row_idxs = row_idxr.get(index=index, freq=freq, template_context=template_context)
        if not isinstance(col_idxr, ColIdxr):
            if isinstance(col_idxr, (RowIdxr, Idxr)):
                raise TypeError(f"Indexer {type(col_idxr)} not supported as a column indexer")
            col_idxr = ColIdxr(col_idxr, **self.idxr_kwargs)
        col_idxs = col_idxr.get(columns=columns, template_context=template_context)
        return row_idxs, col_idxs
@define
class IdxSetter(DefineMixin):
    """Class for setting values based on indexing."""

    idx_items: tp.List[tp.Tuple[object, tp.ArrayLike]] = define.field()
    """Items where the first element is an indexer and the second element is a value to be set."""

    @classmethod
    def set_row_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
        """Set row indices in an array.

        If `idxs` is a two-dimensional array, each of its rows is treated as a pair of
        slice bounds and the corresponding row of `v` (or the whole `v` if it holds
        a single value or a single row) is set for each resulting slice."""
        from vectorbtpro.base.reshaping import broadcast_array_to

        if not isinstance(v, np.ndarray):
            v = np.asarray(v)
        single_v = v.size == 1 or (v.ndim == 2 and v.shape[0] == 1)
        if arr.ndim == 2:
            single_row = not isinstance(idxs, slice) and (np.isscalar(idxs) or idxs.size == 1)
            if not single_row:
                if v.ndim == 1 and v.size > 1:
                    # One value per row: turn into a column to broadcast along columns
                    v = v[:, None]

        if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
            if not single_v:
                if arr.ndim == 2:
                    v = broadcast_array_to(v, (len(idxs), arr.shape[1]))
                else:
                    v = broadcast_array_to(v, (len(idxs),))
            for i in range(len(idxs)):
                _slice = slice(idxs[i, 0], idxs[i, 1])
                if not single_v:
                    cls.set_row_idxs(arr, _slice, v[[i]])
                else:
                    cls.set_row_idxs(arr, _slice, v)
        else:
            arr[idxs] = v

    @classmethod
    def set_col_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
        """Set column indices in an array.

        If `idxs` is a two-dimensional array, each of its rows is treated as a pair of
        slice bounds and the corresponding column of `v` (or the whole `v` if it holds
        a single value or a single column) is set for each resulting slice."""
        from vectorbtpro.base.reshaping import broadcast_array_to

        if not isinstance(v, np.ndarray):
            v = np.asarray(v)
        single_v = v.size == 1 or (v.ndim == 2 and v.shape[1] == 1)

        if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
            if not single_v:
                v = broadcast_array_to(v, (arr.shape[0], len(idxs)))
            for j in range(len(idxs)):
                _slice = slice(idxs[j, 0], idxs[j, 1])
                if not single_v:
                    cls.set_col_idxs(arr, _slice, v[:, [j]])
                else:
                    cls.set_col_idxs(arr, _slice, v)
        else:
            arr[:, idxs] = v

    @classmethod
    def set_row_and_col_idxs(
        cls,
        arr: tp.Array,
        row_idxs: tp.MaybeIndexArray,
        col_idxs: tp.MaybeIndexArray,
        v: tp.Any,
    ) -> None:
        """Set row and column indices in an array.

        Two-dimensional index arrays are expanded into slices (one per array row) and
        handled recursively. Otherwise, the value is set at the cross-product of
        the row and column indices."""
        from vectorbtpro.base.reshaping import broadcast_array_to

        if not isinstance(v, np.ndarray):
            v = np.asarray(v)
        single_v = v.size == 1

        if (
            isinstance(row_idxs, np.ndarray)
            and row_idxs.ndim == 2
            and isinstance(col_idxs, np.ndarray)
            and col_idxs.ndim == 2
        ):
            if not single_v:
                v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
            for i in range(len(row_idxs)):
                for j in range(len(col_idxs)):
                    row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
                    col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
                    if not single_v:
                        cls.set_row_and_col_idxs(arr, row_slice, col_slice, v[i, j])
                    else:
                        cls.set_row_and_col_idxs(arr, row_slice, col_slice, v)
        elif isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2:
            if not single_v:
                if isinstance(col_idxs, slice):
                    # Materialize the slice to know how many columns to broadcast to
                    col_idxs = np.arange(arr.shape[1])[col_idxs]
                v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
            for i in range(len(row_idxs)):
                row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
                if not single_v:
                    cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v[[i]])
                else:
                    cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v)
        elif isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2:
            if not single_v:
                if isinstance(row_idxs, slice):
                    # Materialize the slice to know how many rows to broadcast to
                    row_idxs = np.arange(arr.shape[0])[row_idxs]
                v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
            for j in range(len(col_idxs)):
                col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
                if not single_v:
                    cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v[:, [j]])
                else:
                    cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v)
        else:
            if np.isscalar(row_idxs) or np.isscalar(col_idxs):
                arr[row_idxs, col_idxs] = v
            elif np.isscalar(v) and (isinstance(row_idxs, slice) or isinstance(col_idxs, slice)):
                arr[row_idxs, col_idxs] = v
            elif np.isscalar(v):
                arr[np.ix_(row_idxs, col_idxs)] = v
            else:
                if isinstance(row_idxs, slice):
                    row_idxs = np.arange(arr.shape[0])[row_idxs]
                if isinstance(col_idxs, slice):
                    col_idxs = np.arange(arr.shape[1])[col_idxs]
                v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
                arr[np.ix_(row_idxs, col_idxs)] = v

    def get_set_meta(
        self,
        shape: tp.ShapeLike,
        index: tp.Optional[tp.Index] = None,
        columns: tp.Optional[tp.Index] = None,
        freq: tp.Optional[tp.FrequencyLike] = None,
        template_context: tp.KwargsLike = None,
    ) -> tp.Kwargs:
        """Get meta of setting operations in `IdxSetter.idx_items`.

        Returns a dict with the keys "default" (value under the first non-None `_def` key),
        "set_funcs" (partial functions that take the array to be modified),
        "rows_changed", and "cols_changed"."""
        from vectorbtpro.base.reshaping import to_tuple_shape

        def _check_use_idxs(idxs):
            # A full slice or an empty array means the axis is not being selected
            use_idxs = True
            if isinstance(idxs, slice):
                if idxs.start is None and idxs.stop is None and idxs.step is None:
                    use_idxs = False
            if isinstance(idxs, np.ndarray):
                if idxs.size == 0:
                    use_idxs = False
            return use_idxs

        shape = to_tuple_shape(shape)
        rows_changed = False
        cols_changed = False
        set_funcs = []
        default = None

        for idxr, v in self.idx_items:
            if isinstance(idxr, str) and idxr == "_def":
                if default is None:
                    default = v
                continue

            row_idxs, col_idxs = get_idxs(
                idxr,
                index=index,
                columns=columns,
                freq=freq,
                template_context=template_context,
            )
            if isinstance(v, CustomTemplate):
                _template_context = merge_dicts(
                    dict(
                        idxr=idxr,
                        row_idxs=row_idxs,
                        col_idxs=col_idxs,
                    ),
                    template_context,
                )
                v = v.substitute(_template_context, eval_id="set")
            if not isinstance(v, np.ndarray):
                v = np.asarray(v)

            use_row_idxs = _check_use_idxs(row_idxs)
            use_col_idxs = _check_use_idxs(col_idxs)

            if use_row_idxs and use_col_idxs:
                set_funcs.append(partial(self.set_row_and_col_idxs, row_idxs=row_idxs, col_idxs=col_idxs, v=v))
                rows_changed = True
                cols_changed = True
            elif use_col_idxs:
                set_funcs.append(partial(self.set_col_idxs, idxs=col_idxs, v=v))
                # Rows differ only if the value holds more than one element per column
                if checks.is_int(col_idxs):
                    if v.size > 1:
                        rows_changed = True
                else:
                    if v.ndim == 2:
                        if v.shape[0] > 1:
                            rows_changed = True
                cols_changed = True
            else:
                set_funcs.append(partial(self.set_row_idxs, idxs=row_idxs, v=v))
                if use_row_idxs:
                    rows_changed = True
                # Columns differ only if the value holds more than one element per row
                if len(shape) == 2:
                    if checks.is_int(row_idxs):
                        if v.size > 1:
                            cols_changed = True
                    else:
                        if v.ndim == 2:
                            if v.shape[1] > 1:
                                cols_changed = True
        return dict(
            default=default,
            set_funcs=set_funcs,
            rows_changed=rows_changed,
            cols_changed=cols_changed,
        )

    def set(self, arr: tp.Array, set_funcs: tp.Optional[tp.Sequence[tp.Callable]] = None, **kwargs) -> None:
        """Set values of a NumPy array based on `IdxSetter.get_set_meta`.

        If `set_funcs` is None, they are built using `IdxSetter.get_set_meta` with
        the shape of `arr` and `**kwargs`. The array is modified in place."""
        if set_funcs is None:
            set_meta = self.get_set_meta(arr.shape, **kwargs)
            set_funcs = set_meta["set_funcs"]
        for set_op in set_funcs:
            set_op(arr)

    def set_pd(self, pd_arr: tp.SeriesFrame, **kwargs) -> None:
        """Set values of a Pandas array based on `IdxSetter.get_set_meta`.

        Index, columns, and frequency are taken from `pd_arr`, and the values
        are set on `pd_arr.values`."""
        from vectorbtpro.base.indexes import get_index

        index = get_index(pd_arr, 0)
        columns = get_index(pd_arr, 1)
        freq = dt.infer_index_freq(index)
        self.set(pd_arr.values, index=index, columns=columns, freq=freq, **kwargs)

    def fill_and_set(
        self,
        shape: tp.ShapeLike,
        keep_flex: bool = False,
        fill_value: tp.Scalar = np.nan,
        **kwargs,
    ) -> tp.Array:
        """Fill a new array and set its values based on `IdxSetter.get_set_meta`.

        If `keep_flex` is True, will return the most memory-efficient array representation
        capable of flexible indexing: axes that are not changed by any setting operation
        are kept at length one.

        If there is a `_def` key in `IdxSetter.idx_items` with a value other than None,
        that value overrides `fill_value`. A string fill value results in an object array.

        Keyword arguments `**kwargs` are passed to `IdxSetter.get_set_meta`."""
        from vectorbtpro.base.reshaping import to_tuple_shape

        # Normalize so that an integer shape works with len() and indexing below
        shape = to_tuple_shape(shape)
        set_meta = self.get_set_meta(shape, **kwargs)
        if set_meta["default"] is not None:
            fill_value = set_meta["default"]
        if isinstance(fill_value, str):
            dtype = object
        else:
            dtype = None
        if keep_flex and not set_meta["cols_changed"] and not set_meta["rows_changed"]:
            arr = np.full((1,) if len(shape) == 1 else (1, 1), fill_value, dtype=dtype)
        elif keep_flex and not set_meta["cols_changed"]:
            arr = np.full(shape if len(shape) == 1 else (shape[0], 1), fill_value, dtype=dtype)
        elif keep_flex and not set_meta["rows_changed"]:
            arr = np.full((1, shape[1]), fill_value, dtype=dtype)
        else:
            arr = np.full(shape, fill_value, dtype=dtype)
        self.set(arr, set_funcs=set_meta["set_funcs"])
        return arr
Otherwise, will do one operation per element.""" idx_kwargs: tp.KwargsLike = define.field(default=None) """Keyword arguments passed to `idx` if the indexer isn't an instance of `Idxr`.""" def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]: sr = self.sr split = self.split idx_kwargs = self.idx_kwargs if idx_kwargs is None: idx_kwargs = {} if not isinstance(sr, pd.Series): sr = pd.Series(sr) if split: idx_items = list(sr.items()) else: idx_items = [(sr.index, sr.values)] new_idx_items = [] for idxr, v in idx_items: if idxr is None: raise ValueError("Indexer cannot be None") if not isinstance(idxr, Idxr): idxr = idx(idxr, **idx_kwargs) new_idx_items.append((idxr, v)) return IdxSetter(new_idx_items) @define class IdxFrame(IdxSetterFactory, DefineMixin): """Class for building an index setter from a DataFrame.""" df: tp.AnyArray2d = define.field() """DataFrame or any array-like object to create the DataFrame from.""" split: tp.Union[bool, str] = define.field(default=False) """Whether to split the setting operation. If False, will set all values using a single operation. 
@define
class IdxRecords(IdxSetterFactory, DefineMixin):
    """Class for building index setters from records - one per field."""

    records: tp.RecordsLike = define.field()
    """Series, DataFrame, or any sequence of mapping-like objects.

    If a Series or DataFrame and the index is not a default range, the index will become a row field.
    If a custom row field is provided, the index will be ignored."""

    row_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
    """Row field.

    If None or True, will search for "row", "index", "open time", "date", and "datetime"
    (case-insensitive). If `IdxRecords.records` is a Series or DataFrame, will also include
    the index name if the index is not a default range.

    If a record doesn't have a row field, all rows will be set. If there's no row and column field,
    the field value will become the default of the entire array."""

    col_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
    """Column field.

    If None or True, will search for "col", "column", and "symbol" (case-insensitive).

    If a record doesn't have a column field, all columns will be set. If there's no row and column field,
    the field value will become the default of the entire array."""

    rowidx_kwargs: tp.KwargsLike = define.field(default=None)
    """Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`."""

    colidx_kwargs: tp.KwargsLike = define.field(default=None)
    """Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`."""

    def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
        """Build a dict with one `IdxSetter` per field that is neither a row nor a column field.

        A row or column value equal to the string "_def" is treated as missing. If both
        are missing in a record, the field values are stored under the `_def` key.

        Raises `ValueError` if a requested row/column field cannot be found, if there are
        multiple candidates for it, or if a record has no fields left to set."""
        records = self.records
        row_field = self.row_field
        col_field = self.col_field
        rowidx_kwargs = self.rowidx_kwargs
        colidx_kwargs = self.colidx_kwargs
        if rowidx_kwargs is None:
            rowidx_kwargs = {}
        if colidx_kwargs is None:
            colidx_kwargs = {}

        default_index = False
        index_field = None
        if isinstance(records, pd.Series):
            records = records.to_frame()
        if isinstance(records, pd.DataFrame):
            if checks.is_default_index(records.index):
                default_index = True
            # A non-default index becomes the first column and thus a row field candidate
            records = records.reset_index(drop=default_index)
            if not default_index:
                index_field = records.columns[0]
            records = records.itertuples(index=False)

        def _resolve_field_meta(fields):
            _row_field = row_field
            _row_kind = None
            _col_field = col_field
            _col_kind = None
            row_fields = set()
            col_fields = set()
            for field in fields:
                if isinstance(field, str) and index_field is not None and field == index_field:
                    row_fields.add((field, None))
                if isinstance(field, str) and field.lower() in ("row", "index"):
                    row_fields.add((field, None))
                if isinstance(field, str) and field.lower() in ("open time", "date", "datetime"):
                    # Replace the kind-less candidate (e.g., from the index) by the datetime one
                    if (field, None) in row_fields:
                        row_fields.remove((field, None))
                    row_fields.add((field, "datetime"))
                if isinstance(field, str) and field.lower() in ("col", "column"):
                    col_fields.add((field, None))
                if isinstance(field, str) and field.lower() == "symbol":
                    if (field, None) in col_fields:
                        col_fields.remove((field, None))
                    col_fields.add((field, "labels"))
            if _row_field in (None, True):
                if len(row_fields) == 0:
                    if _row_field is True:
                        raise ValueError("Cannot find row field")
                    _row_field = None
                elif len(row_fields) == 1:
                    _row_field, _row_kind = row_fields.pop()
                else:
                    raise ValueError("Multiple row field candidates")
            elif _row_field is False:
                _row_field = None
            if _col_field in (None, True):
                if len(col_fields) == 0:
                    if _col_field is True:
                        raise ValueError("Cannot find column field")
                    _col_field = None
                elif len(col_fields) == 1:
                    _col_field, _col_kind = col_fields.pop()
                else:
                    raise ValueError("Multiple column field candidates")
            elif _col_field is False:
                _col_field = None
            field_meta = dict()
            field_meta["row_field"] = _row_field
            field_meta["row_kind"] = _row_kind
            field_meta["col_field"] = _col_field
            field_meta["col_kind"] = _col_kind
            return field_meta

        idx_items = dict()
        for r in records:
            r = to_field_mapping(r)
            field_meta = _resolve_field_meta(r.keys())

            if field_meta["row_field"] is None:
                row_idxr = None
            else:
                row_idxr = r.get(field_meta["row_field"], None)
                # Compare strings only: a bare equality check fails on array-like values
                if isinstance(row_idxr, str) and row_idxr == "_def":
                    row_idxr = None
            if row_idxr is not None and not isinstance(row_idxr, RowIdxr):
                _rowidx_kwargs = dict(rowidx_kwargs)
                if field_meta["row_kind"] is not None and "kind" not in _rowidx_kwargs:
                    _rowidx_kwargs["kind"] = field_meta["row_kind"]
                row_idxr = rowidx(row_idxr, **_rowidx_kwargs)

            if field_meta["col_field"] is None:
                col_idxr = None
            else:
                col_idxr = r.get(field_meta["col_field"], None)
                # Fixed: must be checked before wrapping, afterwards the value
                # is a ColIdxr instance and can never match the string
                if isinstance(col_idxr, str) and col_idxr == "_def":
                    col_idxr = None
            if col_idxr is not None and not isinstance(col_idxr, ColIdxr):
                _colidx_kwargs = dict(colidx_kwargs)
                if field_meta["col_kind"] is not None and "kind" not in _colidx_kwargs:
                    _colidx_kwargs["kind"] = field_meta["col_kind"]
                col_idxr = colidx(col_idxr, **_colidx_kwargs)

            item_produced = False
            for k, v in r.items():
                if index_field is not None and k == index_field:
                    continue
                if field_meta["row_field"] is not None and k == field_meta["row_field"]:
                    continue
                if field_meta["col_field"] is not None and k == field_meta["col_field"]:
                    continue
                if k not in idx_items:
                    idx_items[k] = []
                if row_idxr is None and col_idxr is None:
                    idx_items[k].append(("_def", v))
                else:
                    idx_items[k].append((idx(row_idxr, col_idxr), v))
                item_produced = True
            if not item_produced:
                raise ValueError(f"Record {r} has no fields to set")

        idx_setters = dict()
        for k, v in idx_items.items():
            idx_setters[k] = IdxSetter(v)
        return idx_setters
def concat_arrays(*arrs: tp.MaybeSequence[tp.AnyArray]) -> tp.Array1d:
    """Convert each array to one dimension and join them end to end.

    Accepts either multiple arrays or a single sequence of arrays."""
    sources = arrs[0] if len(arrs) == 1 else arrs
    flattened = [to_1d_array(source) for source in sources]
    return np.concatenate(flattened)
def concat_merge(
    *objs,
    keys: tp.Optional[tp.Index] = None,
    filter_results: bool = True,
    raise_no_results: bool = True,
    wrap: tp.Optional[bool] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLikeSequence = None,
    clean_index_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
    """Merge multiple array-like objects through concatenation.

    Supports a sequence of tuples, in which case the elements at each tuple position
    are merged separately (with the same arguments) and a tuple of the same type is returned.

    If `filter_results` is True, objects are first passed through
    `vectorbtpro.utils.execution.filter_out_no_results`. If this raises
    `vectorbtpro.utils.execution.NoResultsException`, the exception is re-raised
    if `raise_no_results` is True, otherwise `vectorbtpro.utils.execution.NoResult` is returned.

    If `wrap` is None, it will become True if the first object is a Series, or if
    `wrapper`, `keys`, or `wrap_kwargs` are not None. If `wrap` is True, each array will be
    wrapped with Pandas Series and merged using `pd.concat`. Otherwise, arrays will be kept
    as-is and merged using `concat_arrays`. `wrap_kwargs` can be a dictionary or a list of
    dictionaries.

    If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap_reduced`.

    Keyword arguments `**kwargs` are passed to `pd.concat` only.

    Raises `ValueError` if there are no objects or the objects are DataFrames,
    and `TypeError` if the objects are `vectorbtpro.base.wrapping.Wrapping` instances.

    !!! note
        All arrays are assumed to have the same type and dimensionality."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    if len(objs) == 0:
        raise ValueError("No objects to be merged")
    if isinstance(objs[0], tuple):
        # Fixed: filter_results, raise_no_results, and clean_index_kwargs
        # used to be dropped when merging tuple elements
        _merge_nested = partial(
            concat_merge,
            keys=keys,
            filter_results=filter_results,
            raise_no_results=raise_no_results,
            wrap=wrap,
            wrapper=wrapper,
            wrap_kwargs=wrap_kwargs,
            clean_index_kwargs=clean_index_kwargs,
            **kwargs,
        )
        if len(objs[0]) == 1:
            out_tuple = (_merge_nested([x[0] for x in objs]),)
        else:
            out_tuple = tuple(_merge_nested(x) for x in zip(*objs))
        if checks.is_namedtuple(objs[0]):
            return type(objs[0])(*out_tuple)
        return type(objs[0])(out_tuple)
    if filter_results:
        try:
            objs, keys = filter_out_no_results(objs, keys=keys)
        except NoResultsException as e:
            if raise_no_results:
                raise e
            return NoResult
    if isinstance(objs[0], Wrapping):
        raise TypeError("Concatenating Wrapping instances is not supported")
    if wrap_kwargs is None:
        wrap_kwargs = {}
    if wrap is None:
        wrap = isinstance(objs[0], pd.Series) or wrapper is not None or keys is not None or len(wrap_kwargs) > 0
    if not checks.is_complex_iterable(objs[0]):
        # Each object is a single element of the resulting array
        if wrap:
            if keys is not None and isinstance(keys[0], pd.Index):
                if len(keys) == 1:
                    keys = keys[0]
                else:
                    keys = concat_indexes(
                        *keys,
                        index_concat_method="append",
                        clean_index_kwargs=clean_index_kwargs,
                        verify_integrity=False,
                        axis=0,
                    )
            wrap_kwargs = merge_dicts(dict(index=keys), wrap_kwargs)
            return pd.Series(objs, **wrap_kwargs)
        return np.asarray(objs)
    if isinstance(objs[0], pd.Index):
        objs = list(map(lambda x: x.to_series(), objs))
    default_index = True
    if not isinstance(objs[0], pd.Series):
        if isinstance(objs[0], pd.DataFrame):
            raise ValueError("Use row stacking for concatenating DataFrames")
        if wrap:
            new_objs = []
            for i, obj in enumerate(objs):
                _wrap_kwargs = resolve_dict(wrap_kwargs, i)
                if wrapper is not None:
                    if "force_1d" not in _wrap_kwargs:
                        _wrap_kwargs["force_1d"] = True
                    new_objs.append(wrapper.wrap_reduced(obj, **_wrap_kwargs))
                else:
                    new_objs.append(pd.Series(obj, **_wrap_kwargs))
                # Remember whether any wrapped object carries a non-default index
                if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True):
                    default_index = False
            objs = new_objs
    if not wrap:
        return concat_arrays(objs)

    if keys is not None and isinstance(keys[0], pd.Index):
        new_obj = pd.concat(objs, axis=0, **kwargs)
        if len(keys) == 1:
            keys = keys[0]
        else:
            keys = concat_indexes(
                *keys,
                index_concat_method="append",
                verify_integrity=False,
                axis=0,
            )
        if default_index:
            # Default indexes carry no information: replace them by the keys
            new_obj.index = keys
        else:
            new_obj.index = stack_indexes((keys, new_obj.index))
    else:
        new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs)
    if clean_index_kwargs is None:
        clean_index_kwargs = {}
    new_obj.index = clean_index(new_obj.index, **clean_index_kwargs)
    return new_obj
note All arrays are assumed to have the same type and dimensionality.""" if len(objs) == 1: objs = objs[0] objs = list(objs) if len(objs) == 0: raise ValueError("No objects to be merged") if isinstance(objs[0], tuple): if len(objs[0]) == 1: out_tuple = ( row_stack_merge( list(map(lambda x: x[0], objs)), keys=keys, wrap=wrap, wrapper=wrapper, wrap_kwargs=wrap_kwargs, **kwargs, ), ) else: out_tuple = tuple( map( lambda x: row_stack_merge( x, keys=keys, wrap=wrap, wrapper=wrapper, wrap_kwargs=wrap_kwargs, **kwargs, ), zip(*objs), ) ) if checks.is_namedtuple(objs[0]): return type(objs[0])(*out_tuple) return type(objs[0])(out_tuple) if filter_results: try: objs, keys = filter_out_no_results(objs, keys=keys) except NoResultsException as e: if raise_no_results: raise e return NoResult if isinstance(objs[0], Wrapping): kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs) return type(objs[0]).row_stack(objs, **kwargs) if wrap_kwargs is None: wrap_kwargs = {} if wrap is None: wrap = ( isinstance(objs[0], (pd.Series, pd.DataFrame)) or wrapper is not None or keys is not None or len(wrap_kwargs) > 0 ) if isinstance(objs[0], pd.Index): objs = list(map(lambda x: x.to_series(), objs)) default_index = True if not isinstance(objs[0], (pd.Series, pd.DataFrame)): if isinstance(wrap, str) or wrap: new_objs = [] for i, obj in enumerate(objs): _wrap_kwargs = resolve_dict(wrap_kwargs, i) if wrapper is not None: new_objs.append(wrapper.wrap(obj, **_wrap_kwargs)) else: if not isinstance(wrap, str): if isinstance(obj, np.ndarray): ndim = obj.ndim else: ndim = np.asarray(obj).ndim if ndim == 1: wrap = "series" else: wrap = "frame" if isinstance(wrap, str): if wrap.lower() in ("sr", "series"): new_objs.append(pd.Series(obj, **_wrap_kwargs)) elif wrap.lower() in ("df", "frame", "dataframe"): new_objs.append(pd.DataFrame(obj, **_wrap_kwargs)) else: raise ValueError(f"Invalid wrapping option: '{wrap}'") if default_index and not checks.is_default_index(new_objs[-1].index, 
def column_stack_merge(
    *objs,
    reset_index: tp.Union[None, bool, str] = None,
    fill_value: tp.Scalar = np.nan,
    keys: tp.Optional[tp.Index] = None,
    filter_results: bool = True,
    raise_no_results: bool = True,
    wrap: tp.Union[None, str, bool] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLikeSequence = None,
    clean_index_kwargs: tp.KwargsLikeSequence = None,
    **kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
    """Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through column stacking.

    Supports a sequence of tuples: each position of the tuples is merged separately
    using the same arguments, and a tuple (or named tuple) of the same type is returned.

    Argument `wrap` supports the following options:

    * None: will become True if the first object is a Pandas Series/DataFrame,
        or if `wrapper`, `keys`, or `wrap_kwargs` are not None
    * True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions)
    * 'sr', 'series': each array will be wrapped with Pandas Series
    * 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame

    Without wrapping, arrays will be kept as-is and merged using `column_stack_arrays`.
    Argument `wrap_kwargs` can be a dictionary or a list of dictionaries.

    If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.

    Keyword arguments `**kwargs` are passed to `pd.concat` and
    `vectorbtpro.base.wrapping.Wrapping.column_stack` only.

    Argument `reset_index` supports the following options:

    * False or None: Keep original index of each object
    * True or 'from_start': Reset index of each object and align them at start
    * 'from_end': Reset index of each object and align them at end

    Options above work on Pandas, NumPy, and `vectorbtpro.base.wrapping.Wrapping` instances.

    Argument `fill_value` is used to pad NumPy arrays of unequal length when `reset_index`
    is set and no wrapping takes place.

    If `filter_results` is True, objects are passed through `filter_out_no_results`; if nothing
    remains, the raised `NoResultsException` is either propagated (`raise_no_results=True`)
    or `NoResult` is returned.

    Raises `ValueError` if there are no objects, or if `wrap` or `reset_index` is invalid.

    !!! note
        All arrays are assumed to have the same type and dimensionality."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    if len(objs) == 0:
        raise ValueError("No objects to be merged")
    if isinstance(reset_index, bool):
        # True is an alias for 'from_start', False for None
        reset_index = "from_start" if reset_index else None
    # Options are case-insensitive; non-string values are rejected where the option is applied
    reset_mode = reset_index.lower() if isinstance(reset_index, str) else reset_index

    if isinstance(objs[0], tuple):
        # Forward every option so that each tuple position is merged exactly like
        # a top-level call would be (previously some options fell back to defaults)
        merge_kwargs = dict(
            reset_index=reset_index,
            fill_value=fill_value,
            keys=keys,
            filter_results=filter_results,
            raise_no_results=raise_no_results,
            wrap=wrap,
            wrapper=wrapper,
            wrap_kwargs=wrap_kwargs,
            clean_index_kwargs=clean_index_kwargs,
            **kwargs,
        )
        if len(objs[0]) == 1:
            out_tuple = (column_stack_merge([x[0] for x in objs], **merge_kwargs),)
        else:
            out_tuple = tuple(column_stack_merge(x, **merge_kwargs) for x in zip(*objs))
        if checks.is_namedtuple(objs[0]):
            return type(objs[0])(*out_tuple)
        return type(objs[0])(out_tuple)

    if filter_results:
        try:
            objs, keys = filter_out_no_results(objs, keys=keys)
        except NoResultsException as e:
            if raise_no_results:
                raise e
            return NoResult

    if isinstance(objs[0], Wrapping):
        if reset_index is not None:
            max_length = max(map(lambda x: x.wrapper.shape[0], objs))
            new_objs = []
            for obj in objs:
                n_rows = obj.wrapper.shape[0]
                if reset_mode == "from_start":
                    new_index = pd.RangeIndex(stop=n_rows)
                elif reset_mode == "from_end":
                    new_index = pd.RangeIndex(start=max_length - n_rows, stop=max_length)
                else:
                    raise ValueError(f"Invalid index resetting option: '{reset_index}'")
                new_objs.append(obj.replace(wrapper=obj.wrapper.replace(index=new_index)))
            objs = new_objs
        kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs)
        return type(objs[0]).column_stack(objs, **kwargs)

    if wrap_kwargs is None:
        wrap_kwargs = {}
    if wrap is None:
        wrap = (
            isinstance(objs[0], (pd.Series, pd.DataFrame))
            or wrapper is not None
            or keys is not None
            or len(wrap_kwargs) > 0
        )
    if isinstance(objs[0], pd.Index):
        objs = list(map(lambda x: x.to_series(), objs))
    default_columns = True
    if not isinstance(objs[0], (pd.Series, pd.DataFrame)):
        if isinstance(wrap, str) or wrap:
            new_objs = []
            for i, obj in enumerate(objs):
                _wrap_kwargs = resolve_dict(wrap_kwargs, i)
                if wrapper is not None:
                    new_objs.append(wrapper.wrap(obj, **_wrap_kwargs))
                else:
                    if not isinstance(wrap, str):
                        # Decide between Series and DataFrame once, based on the first array
                        if isinstance(obj, np.ndarray):
                            ndim = obj.ndim
                        else:
                            ndim = np.asarray(obj).ndim
                        wrap = "series" if ndim == 1 else "frame"
                    if wrap.lower() in ("sr", "series"):
                        new_objs.append(pd.Series(obj, **_wrap_kwargs))
                    elif wrap.lower() in ("df", "frame", "dataframe"):
                        new_objs.append(pd.DataFrame(obj, **_wrap_kwargs))
                    else:
                        raise ValueError(f"Invalid wrapping option: '{wrap}'")
                if (
                    default_columns
                    and isinstance(new_objs[-1], pd.DataFrame)
                    and not checks.is_default_index(new_objs[-1].columns, check_names=True)
                ):
                    default_columns = False
            objs = new_objs

    if not wrap:
        if reset_index is not None:
            new_objs = [to_2d_array(obj) for obj in objs]
            row_counts = [new_obj.shape[0] for new_obj in new_objs]
            min_n_rows = min(row_counts)
            # The maximum must be tracked independently of the minimum; comparing against
            # the minimum under-allocated the output for lengths such as (5, 3, 4)
            max_n_rows = max(row_counts)
            if min_n_rows == max_n_rows:
                return column_stack_arrays(new_objs)
            n_cols = sum(new_obj.shape[1] for new_obj in new_objs)
            new_obj = np.full((max_n_rows, n_cols), fill_value)
            start_col = 0
            for obj in new_objs:
                end_col = start_col + obj.shape[1]
                if reset_mode == "from_start":
                    new_obj[: obj.shape[0], start_col:end_col] = obj
                elif reset_mode == "from_end":
                    # An explicit start offset stays correct for empty arrays,
                    # whereas a slice starting at -0 would select every row
                    new_obj[max_n_rows - obj.shape[0] :, start_col:end_col] = obj
                else:
                    raise ValueError(f"Invalid index resetting option: '{reset_index}'")
                start_col = end_col
            return new_obj
        return column_stack_arrays(objs)

    if reset_index is not None:
        max_length = max(map(len, objs))
        new_objs = []
        for obj in objs:
            # Shallow copy: only the index is replaced, the data is shared
            new_obj = obj.copy(deep=False)
            if reset_mode == "from_start":
                new_obj.index = pd.RangeIndex(stop=len(new_obj))
            elif reset_mode == "from_end":
                new_obj.index = pd.RangeIndex(start=max_length - len(new_obj), stop=max_length)
            else:
                raise ValueError(f"Invalid index resetting option: '{reset_index}'")
            new_objs.append(new_obj)
        objs = new_objs
    kwargs = merge_dicts(dict(sort=True), kwargs)
    if keys is not None and isinstance(keys[0], pd.Index):
        # Keys are index objects: concatenate first, then build the columns manually
        new_obj = pd.concat(objs, axis=1, **kwargs)
        if len(keys) == 1:
            keys = keys[0]
        else:
            keys = concat_indexes(
                *keys,
                index_concat_method="append",
                verify_integrity=False,
                axis=1,
            )
        if default_columns:
            new_obj.columns = keys
        else:
            new_obj.columns = stack_indexes((keys, new_obj.columns))
    else:
        new_obj = pd.concat(objs, axis=1, keys=keys, **kwargs)
    if clean_index_kwargs is None:
        clean_index_kwargs = {}
    new_obj.columns = clean_index(new_obj.columns, **clean_index_kwargs)
    return new_obj
def mixed_merge(
    *objs,
    merge_funcs: tp.Optional[tp.MergeFuncLike] = None,
    mixed_kwargs: tp.Optional[tp.Sequence[tp.KwargsLike]] = None,
    **kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
    """Merge tuples of mixed-type objects, position by position.

    The objects must be tuples. The elements at position `i` of all tuples are collected
    into a list and merged with the function resolved from `merge_funcs[i]` using
    `resolve_merge_func`. If that function resolves to None, the list is returned unmerged.

    Keyword arguments `**kwargs` are passed to each merging function. If `mixed_kwargs`
    is provided, `mixed_kwargs[i]` is merged over `**kwargs` for position `i`.

    Returns a tuple with one output per position."""
    candidates = list(objs[0]) if len(objs) == 1 else list(objs)
    if not candidates:
        raise ValueError("No objects to be merged")
    if merge_funcs is None:
        raise ValueError("Merging functions or their names are required")
    if not isinstance(candidates[0], tuple):
        raise ValueError("Mixed merging must be applied on tuples")

    def _merge_position(pos, position_objs):
        # Merge everything found at one tuple position
        func = resolve_merge_func(merge_funcs[pos])
        if func is None:
            return position_objs
        if mixed_kwargs is None:
            call_kwargs = kwargs
        else:
            call_kwargs = merge_dicts(kwargs, mixed_kwargs[pos])
        return func(position_objs, **call_kwargs)

    return tuple(
        _merge_position(pos, list(position_objs))
        for pos, position_objs in enumerate(zip(*candidates))
    )
==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Classes for preparing arguments.""" import inspect import string from collections import defaultdict from datetime import timedelta, time from functools import cached_property as cachedproperty from pathlib import Path import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.base.decorators import override_arg_config, attach_arg_properties from vectorbtpro.base.indexes import repeat_index from vectorbtpro.base.indexing import index_dict, IdxSetter, IdxSetterFactory, IdxRecords from vectorbtpro.base.merging import concat_arrays, column_stack_arrays from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.reshaping import BCO, Default, Ref, broadcast from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.attr_ import get_dict_attr from vectorbtpro.utils.config import Configured from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig from vectorbtpro.utils.cutting import suggest_module_path, cut_and_save_func from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.module_ import import_module_from_path from vectorbtpro.utils.params import Param from vectorbtpro.utils.parsing import get_func_arg_names from vectorbtpro.utils.path_ import remove_dir from vectorbtpro.utils.random_ import set_seed from vectorbtpro.utils.template import CustomTemplate, RepFunc, substitute_templates 
__all__ = [ "BasePreparer", ] __pdoc__ = {} base_arg_config = ReadonlyConfig( dict( broadcast_named_args=dict(is_dict=True), broadcast_kwargs=dict(is_dict=True), template_context=dict(is_dict=True), seed=dict(), jitted=dict(), chunked=dict(), staticized=dict(), records=dict(), ) ) """_""" __pdoc__[ "base_arg_config" ] = f"""Argument config for `BasePreparer`. ```python {base_arg_config.prettify()} ``` """ class MetaBasePreparer(type(Configured)): """Metaclass for `BasePreparer`.""" @property def arg_config(cls) -> Config: """Argument config.""" return cls._arg_config @attach_arg_properties @override_arg_config(base_arg_config) class BasePreparer(Configured, metaclass=MetaBasePreparer): """Base class for preparing target functions and arguments. !!! warning Most properties are force-cached - create a new instance to override any attribute.""" _expected_keys_mode: tp.ExpectedKeysMode = "disable" _writeable_attrs: tp.WriteableAttrs = {"_arg_config"} _settings_path: tp.SettingsPath = None def __init__(self, arg_config: tp.KwargsLike = None, **kwargs) -> None: Configured.__init__(self, arg_config=arg_config, **kwargs) # Copy writeable attrs self._arg_config = type(self)._arg_config.copy() if arg_config is not None: self._arg_config = merge_dicts(self._arg_config, arg_config) _arg_config: tp.ClassVar[Config] = HybridConfig() @property def arg_config(self) -> Config: """Argument config of `${cls_name}`. 
```python ${arg_config} ``` """ return self._arg_config @classmethod def map_enum_value(cls, value: tp.ArrayLike, look_for_type: tp.Optional[type] = None, **kwargs) -> tp.ArrayLike: """Map enumerated value(s).""" if look_for_type is not None: if isinstance(value, look_for_type): return map_enum_fields(value, **kwargs) return value if isinstance(value, (CustomTemplate, Ref)): return value if isinstance(value, (Param, BCO, Default)): attr_dct = value.asdict() if isinstance(value, Param) and attr_dct["map_template"] is None: attr_dct["map_template"] = RepFunc(lambda values: cls.map_enum_value(values, **kwargs)) elif not isinstance(value, Param): attr_dct["value"] = cls.map_enum_value(attr_dct["value"], **kwargs) return type(value)(**attr_dct) if isinstance(value, index_dict): return index_dict({k: cls.map_enum_value(v, **kwargs) for k, v in value.items()}) if isinstance(value, IdxSetterFactory): value = value.get() if not isinstance(value, IdxSetter): raise ValueError("Index setter factory must return exactly one index setter") if isinstance(value, IdxSetter): return IdxSetter([(k, cls.map_enum_value(v, **kwargs)) for k, v in value.idx_items]) return map_enum_fields(value, **kwargs) @classmethod def prepare_td_obj(cls, td_obj: object, old_as_keys: bool = True) -> object: """Prepare a timedelta object for broadcasting.""" if isinstance(td_obj, Param): return td_obj.map_value(cls.prepare_td_obj, old_as_keys=old_as_keys) if isinstance(td_obj, (str, timedelta, pd.DateOffset, pd.Timedelta)): td_obj = dt.to_timedelta64(td_obj) elif isinstance(td_obj, pd.Index): td_obj = td_obj.values return td_obj @classmethod def prepare_dt_obj( cls, dt_obj: object, old_as_keys: bool = True, last_before: tp.Optional[bool] = None, ) -> object: """Prepare a datetime object for broadcasting.""" if isinstance(dt_obj, Param): return dt_obj.map_value(cls.prepare_dt_obj, old_as_keys=old_as_keys) if isinstance(dt_obj, (str, time, timedelta, pd.DateOffset, pd.Timedelta)): def 
_apply_last_before(source_index, target_index, source_freq): resampler = Resampler(source_index, target_index, source_freq=source_freq) last_indices = resampler.last_before_target_index(incl_source=False) source_rbound_ns = resampler.source_rbound_index.vbt.to_ns() return np.where(last_indices != -1, source_rbound_ns[last_indices], -1) def _to_dt(wrapper, _dt_obj=dt_obj, _last_before=last_before): if _last_before is None: _last_before = False _dt_obj = dt.try_align_dt_to_index(_dt_obj, wrapper.index) source_index = wrapper.index[wrapper.index < _dt_obj] target_index = repeat_index(pd.Index([_dt_obj]), len(source_index)) if _last_before: target_ns = _apply_last_before(source_index, target_index, wrapper.freq) else: target_ns = target_index.vbt.to_ns() if len(target_ns) < len(wrapper.index): target_ns = concat_arrays((target_ns, np.full(len(wrapper.index) - len(target_ns), -1))) return target_ns def _to_td(wrapper, _dt_obj=dt_obj, _last_before=last_before): if _last_before is None: _last_before = True target_index = wrapper.index.vbt.to_period_ts(dt.to_freq(_dt_obj), shift=True) if _last_before: return _apply_last_before(wrapper.index, target_index, wrapper.freq) return target_index.vbt.to_ns() def _to_time(wrapper, _dt_obj=dt_obj, _last_before=last_before): if _last_before is None: _last_before = False floor_index = wrapper.index.floor("1d") + dt.time_to_timedelta(_dt_obj) target_index = floor_index.where(wrapper.index < floor_index, floor_index + pd.Timedelta(days=1)) if _last_before: return _apply_last_before(wrapper.index, target_index, wrapper.freq) return target_index.vbt.to_ns() dt_obj_dt_template = RepFunc(_to_dt) dt_obj_td_template = RepFunc(_to_td) dt_obj_time_template = RepFunc(_to_time) if isinstance(dt_obj, str): try: time.fromisoformat(dt_obj) dt_obj = dt_obj_time_template except Exception as e: try: dt.to_freq(dt_obj) dt_obj = dt_obj_td_template except Exception as e: dt_obj = dt_obj_dt_template elif isinstance(dt_obj, time): dt_obj = 
dt_obj_time_template else: dt_obj = dt_obj_td_template elif isinstance(dt_obj, pd.Index): dt_obj = dt_obj.values return dt_obj def get_raw_arg_default(self, arg_name: str, is_dict: bool = False) -> tp.Any: """Get raw argument default.""" if self._settings_path is None: if is_dict: return {} return None value = self.get_setting(arg_name) if is_dict and value is None: return {} return value def get_raw_arg(self, arg_name: str, is_dict: bool = False, has_default: bool = True) -> tp.Any: """Get raw argument.""" value = self.config.get(arg_name, None) if is_dict: if has_default: return merge_dicts(self.get_raw_arg_default(arg_name), value) if value is None: return {} return value if value is None and has_default: return self.get_raw_arg_default(arg_name) return value @cachedproperty def idx_setters(self) -> tp.Optional[tp.Dict[tp.Label, IdxSetter]]: """Index setters from resolving the argument `records`.""" arg_config = self.arg_config["records"] records = self.get_raw_arg( "records", is_dict=arg_config.get("is_dict", False), has_default=arg_config.get("has_default", True), ) if records is None: return None if not isinstance(records, IdxRecords): records = IdxRecords(records) idx_setters = records.get() for k in idx_setters: if k in self.arg_config and not self.arg_config[k].get("broadcast", False): raise ValueError(f"Field {k} is not broadcastable and cannot be included in records") rename_fields = arg_config.get("rename_fields", {}) new_idx_setters = {} for k, v in idx_setters.items(): if k in rename_fields: k = rename_fields[k] new_idx_setters[k] = v return new_idx_setters def get_arg_default(self, arg_name: str) -> tp.Any: """Get argument default according to the argument config.""" arg_config = self.arg_config[arg_name] arg = self.get_raw_arg_default( arg_name, is_dict=arg_config.get("is_dict", False), ) if arg is not None: if len(arg_config.get("map_enum_kwargs", {})) > 0: arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"]) if arg_config.get("is_td", 
False): arg = self.prepare_td_obj( arg, old_as_keys=arg_config.get("old_as_keys", True), ) if arg_config.get("is_dt", False): arg = self.prepare_dt_obj( arg, old_as_keys=arg_config.get("old_as_keys", True), last_before=arg_config.get("last_before", None), ) return arg def get_arg(self, arg_name: str, use_idx_setter: bool = True, use_default: bool = True) -> tp.Any: """Get mapped argument according to the argument config.""" arg_config = self.arg_config[arg_name] if use_idx_setter and self.idx_setters is not None and arg_name in self.idx_setters: arg = self.idx_setters[arg_name] else: arg = self.get_raw_arg( arg_name, is_dict=arg_config.get("is_dict", False), has_default=arg_config.get("has_default", True) if use_default else False, ) if arg is not None: if len(arg_config.get("map_enum_kwargs", {})) > 0: arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"]) if arg_config.get("is_td", False): arg = self.prepare_td_obj(arg) if arg_config.get("is_dt", False): arg = self.prepare_dt_obj(arg, last_before=arg_config.get("last_before", None)) return arg def __getitem__(self, arg_name) -> tp.Any: return self.get_arg(arg_name) def __iter__(self): raise TypeError(f"'{type(self).__name__}' object is not iterable") @classmethod def prepare_td_arr(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike: """Prepare a timedelta array.""" if td_arr.dtype == object: if td_arr.ndim in (0, 1): td_arr = pd.to_timedelta(td_arr) if isinstance(td_arr, pd.Timedelta): td_arr = td_arr.to_timedelta64() else: td_arr = td_arr.values else: td_arr_cols = [] for col in range(td_arr.shape[1]): td_arr_col = pd.to_timedelta(td_arr[:, col]) td_arr_cols.append(td_arr_col.values) td_arr = column_stack_arrays(td_arr_cols) return td_arr @classmethod def prepare_dt_arr(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike: """Prepare a datetime array.""" if dt_arr.dtype == object: if dt_arr.ndim in (0, 1): dt_arr = pd.to_datetime(dt_arr).tz_localize(None) if isinstance(dt_arr, pd.Timestamp): dt_arr = 
dt_arr.to_datetime64() else: dt_arr = dt_arr.values else: dt_arr_cols = [] for col in range(dt_arr.shape[1]): dt_arr_col = pd.to_datetime(dt_arr[:, col]).tz_localize(None) dt_arr_cols.append(dt_arr_col.values) dt_arr = column_stack_arrays(dt_arr_cols) return dt_arr @classmethod def td_arr_to_ns(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike: """Prepare a timedelta array and convert it to nanoseconds.""" return dt.to_ns(cls.prepare_td_arr(td_arr)) @classmethod def dt_arr_to_ns(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike: """Prepare a datetime array and convert it to nanoseconds.""" return dt.to_ns(cls.prepare_dt_arr(dt_arr)) def prepare_post_arg(self, arg_name: str, value: tp.Optional[tp.ArrayLike] = None) -> object: """Prepare an argument after broadcasting and/or template substitution.""" if value is None: if arg_name in self.post_args: arg = self.post_args[arg_name] else: arg = getattr(self, "_pre_" + arg_name) else: arg = value if arg is not None: arg_config = self.arg_config[arg_name] if arg_config.get("substitute_templates", False): arg = substitute_templates(arg, self.template_context, eval_id=arg_name) if "map_enum_kwargs" in arg_config: arg = map_enum_fields(arg, **arg_config["map_enum_kwargs"]) if arg_config.get("is_td", False): arg = self.td_arr_to_ns(arg) if arg_config.get("is_dt", False): arg = self.dt_arr_to_ns(arg) if "type" in arg_config: checks.assert_instance_of(arg, arg_config["type"], arg_name=arg_name) if "subdtype" in arg_config: checks.assert_subdtype(arg, arg_config["subdtype"], arg_name=arg_name) return arg @classmethod def adapt_staticized_to_udf(cls, staticized: tp.Kwargs, func: tp.Union[str, tp.Callable], func_name: str) -> None: """Adapt `staticized` dictionary to a UDF.""" target_func_module = inspect.getmodule(staticized["func"]) if isinstance(func, tuple): func, actual_func_name = func else: actual_func_name = None if isinstance(func, (str, Path)): if actual_func_name is None: actual_func_name = func_name if isinstance(func, str) and not 
func.endswith(".py") and hasattr(target_func_module, func): staticized[f"{func_name}_block"] = func return None func = Path(func) module_path = func.resolve() else: if actual_func_name is None: actual_func_name = func.__name__ if inspect.getmodule(func) == target_func_module: staticized[f"{func_name}_block"] = actual_func_name return None module = inspect.getmodule(func) if not hasattr(module, "__file__"): raise TypeError(f"{func_name} must be defined in a Python file") module_path = Path(module.__file__).resolve() if "import_lines" not in staticized: staticized["import_lines"] = [] reload = staticized.get("reload", False) staticized["import_lines"].extend( [ f'{func_name}_path = r"{module_path}"', f"globals().update(vbt.import_module_from_path({func_name}_path).__dict__, reload={reload})", ] ) if actual_func_name != func_name: staticized["import_lines"].append(f"{func_name} = {actual_func_name}") @classmethod def find_target_func(cls, target_func_name: str) -> tp.Callable: """Find target function by its name.""" raise NotImplementedError @classmethod def resolve_dynamic_target_func(cls, target_func_name: str, staticized: tp.KwargsLike) -> tp.Callable: """Resolve a dynamic target function.""" if staticized is None: func = cls.find_target_func(target_func_name) else: if isinstance(staticized, dict): staticized = dict(staticized) module_path = suggest_module_path( staticized.get("suggest_fname", target_func_name), path=staticized.pop("path", None), mkdir_kwargs=staticized.get("mkdir_kwargs", None), ) if "new_func_name" not in staticized: staticized["new_func_name"] = target_func_name if staticized.pop("override", False) or not module_path.exists(): if "skip_func" not in staticized: def _skip_func(out_lines, func_name): to_skip = lambda x: f"def {func_name}" in x or x.startswith(f"{func_name}_path =") return any(map(to_skip, out_lines)) staticized["skip_func"] = _skip_func module_path = cut_and_save_func(path=module_path, **staticized) if staticized.get("clear_cache", 
True): remove_dir(module_path.parent / "__pycache__", with_contents=True, missing_ok=True) reload = staticized.pop("reload", False) module = import_module_from_path(module_path, reload=reload) func = getattr(module, staticized["new_func_name"]) else: func = staticized return func def set_seed(self) -> None: """Set seed.""" seed = self.seed if seed is not None: set_seed(seed) # ############# Before broadcasting ############# # @cachedproperty def _pre_template_context(self) -> tp.Kwargs: """Argument `template_context` before broadcasting.""" return merge_dicts(dict(preparer=self), self["template_context"]) # ############# Broadcasting ############# # @cachedproperty def pre_args(self) -> tp.Kwargs: """Arguments before broadcasting.""" pre_args = dict() for k, v in self.arg_config.items(): if v.get("broadcast", False): pre_args[k] = getattr(self, "_pre_" + k) return pre_args @cachedproperty def args_to_broadcast(self) -> dict: """Arguments to broadcast.""" return merge_dicts(self.idx_setters, self.pre_args, self.broadcast_named_args) @cachedproperty def def_broadcast_kwargs(self) -> tp.Kwargs: """Default keyword arguments for broadcasting.""" return dict( to_pd=False, keep_flex=dict(cash_earnings=self.keep_inout_flex, _def=True), wrapper_kwargs=dict( freq=self._pre_freq, group_by=self.group_by, ), return_wrapper=True, template_context=self._pre_template_context, ) @cachedproperty def broadcast_kwargs(self) -> tp.Kwargs: """Argument `broadcast_kwargs`.""" arg_broadcast_kwargs = defaultdict(dict) for k, v in self.arg_config.items(): if v.get("broadcast", False): broadcast_kwargs = v.get("broadcast_kwargs", None) if broadcast_kwargs is None: broadcast_kwargs = {} for k2, v2 in broadcast_kwargs.items(): arg_broadcast_kwargs[k2][k] = v2 for k in self.args_to_broadcast: new_fill_value = None if k in self.pre_args: fill_default = self.arg_config[k].get("fill_default", True) if self.idx_setters is not None and k in self.idx_setters: new_fill_value = self.get_arg(k, 
use_idx_setter=False, use_default=fill_default) elif fill_default and self.arg_config[k].get("has_default", True): new_fill_value = self.get_arg_default(k) elif k in self.broadcast_named_args: if self.idx_setters is not None and k in self.idx_setters: new_fill_value = self.broadcast_named_args[k] if new_fill_value is not None: if not np.isscalar(new_fill_value): raise TypeError(f"Argument '{k}' (and its default) must be a scalar when also provided via records") if "reindex_kwargs" not in arg_broadcast_kwargs: arg_broadcast_kwargs["reindex_kwargs"] = {} if k not in arg_broadcast_kwargs["reindex_kwargs"]: arg_broadcast_kwargs["reindex_kwargs"][k] = {} arg_broadcast_kwargs["reindex_kwargs"][k]["fill_value"] = new_fill_value return merge_dicts( self.def_broadcast_kwargs, dict(arg_broadcast_kwargs), self["broadcast_kwargs"], ) @cachedproperty def broadcast_result(self) -> tp.Any: """Result of broadcasting.""" return broadcast(self.args_to_broadcast, **self.broadcast_kwargs) @cachedproperty def post_args(self) -> tp.Kwargs: """Arguments after broadcasting.""" return self.broadcast_result[0] @cachedproperty def post_broadcast_named_args(self) -> tp.Kwargs: """Custom arguments after broadcasting.""" if self.broadcast_named_args is None: return dict() post_broadcast_named_args = dict() for k, v in self.post_args.items(): if k in self.broadcast_named_args: post_broadcast_named_args[k] = v elif self.idx_setters is not None and k in self.idx_setters and k not in self.pre_args: post_broadcast_named_args[k] = v return post_broadcast_named_args @cachedproperty def wrapper(self) -> ArrayWrapper: """Array wrapper.""" return self.broadcast_result[1] @cachedproperty def target_shape(self) -> tp.Shape: """Target shape.""" return self.wrapper.shape_2d @cachedproperty def index(self) -> tp.Array1d: """Index in nanosecond format.""" return self.wrapper.ns_index @cachedproperty def freq(self) -> int: """Frequency in nanosecond format.""" return self.wrapper.ns_freq # ############# 
@cachedproperty
def target_args(self) -> tp.Optional[tp.Kwargs]:
    """Keyword arguments to be passed to the target function.

    Returns None if `target_func` is None. Otherwise, iterates over the argument names
    of the target function, translates each name through `target_arg_map` (falling back
    to the name itself), and reads the attribute with the translated name from this instance.
    Names mapped to None and names without a matching attribute are left out.
    """
    func = self.target_func
    if func is None:
        return None

    attr_by_arg = self.target_arg_map
    resolved = {}
    for arg_name in get_func_arg_names(func):
        attr_name = attr_by_arg.get(arg_name, arg_name)
        if attr_name is None:
            # Explicitly disabled in the map
            continue
        if not hasattr(self, attr_name):
            continue
        resolved[arg_name] = getattr(self, attr_name)
    return resolved
def to_2d_shape(shape: tp.ShapeLike, expand_axis: int = 1) -> tp.Shape:
    """Convert a shape-like object to a 2-dim shape.

    Args:
        shape (int or sequence of int): Shape-like object, converted with `to_tuple_shape`.
        expand_axis (int): Position of the new axis when `shape` is 1-dim.

            If 1, `(n,)` becomes `(n, 1)`. Otherwise, `(n,)` becomes `(1, n)`.

    Returns:
        `(1, 1)` for an empty shape, the expanded shape for a 1-dim shape,
        and the shape itself for a 2-dim shape.

    Raises:
        ValueError: If the shape has more than 2 dimensions.
    """
    shape = to_tuple_shape(shape)
    if len(shape) == 0:
        return 1, 1
    if len(shape) == 1:
        if expand_axis == 1:
            return shape[0], 1
        # The new axis goes in front, mirroring `np.expand_dims(arr, 0)` as used by `to_2d`.
        # Returning `(shape[0], 0)` would describe an empty array and make broadcasting
        # against other 2-dim shapes fail or collapse to zero columns.
        return 1, shape[0]
    if len(shape) == 2:
        return shape
    raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 2 dimensions")
def soft_to_ndim(obj: tp.ArrayLike, ndim: int, raw: bool = False) -> tp.AnyArray:
    """Softly bring `obj` to `ndim` dimensions (1 or 2), returning it unchanged if not possible.

    The object is first converted with `to_any_array`. A 2-dim object with exactly one column
    is downgraded to 1 dimension when `ndim` is 1, and a 1-dim object is upgraded to
    one column when `ndim` is 2. In any other case, the converted object is returned as is.
    """
    arr = to_any_array(obj, raw=raw)

    can_downgrade = ndim == 1 and arr.ndim == 2 and arr.shape[1] == 1
    if can_downgrade:
        return arr.iloc[:, 0] if checks.is_frame(arr) else arr[:, 0]

    can_upgrade = ndim == 2 and arr.ndim == 1
    if can_upgrade:
        return arr.to_frame() if checks.is_series(arr) else arr[:, None]

    return arr
def to_2d(obj: tp.ArrayLike, raw: bool = False, expand_axis: int = 1) -> tp.AnyArray2d:
    """Reshape argument to two dimensions.

    The object is first converted with `to_any_array` (to a NumPy array if `raw` is True).

    A 2-dim object is returned as is and a 0-dim array is reshaped to `(1, 1)`.
    A 1-dim object gets a new axis at position `expand_axis`: a Series becomes a DataFrame
    with one column (`expand_axis=1`) or with one row whose columns are the Series index
    (`expand_axis=0`), anything else goes through `np.expand_dims`.

    Raises:
        ValueError: If the object has more than 2 dimensions.
    """
    arr = to_any_array(obj, raw=raw)
    n_dims = arr.ndim

    if n_dims == 2:
        return arr
    if n_dims == 0:
        return arr.reshape((1, 1))
    if n_dims != 1:
        raise ValueError(f"Cannot reshape a {n_dims}-dimensional array to 2 dimensions")

    if checks.is_series(arr):
        if expand_axis == 1:
            return arr.to_frame()
        if expand_axis == 0:
            return pd.DataFrame(arr.values[None, :], columns=arr.index)
    return np.expand_dims(arr, expand_axis)
def repeat(
    obj: tp.ArrayLike,
    n: int,
    axis: int = 1,
    raw: bool = False,
    ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
    """Repeat each element of `obj` `n` times along the specified axis.

    The object is first converted with `to_any_array`. For axis 1, it is additionally
    brought to two dimensions with `to_2d`. Pandas objects get their index (axis 0) or
    columns (axis 1) repeated with `vectorbtpro.base.indexes.repeat_index` and are
    re-wrapped, anything else is passed to `np.repeat`.

    Raises:
        ValueError: If `axis` is neither 0 nor 1.
    """
    arr = to_any_array(obj, raw=raw)

    if axis == 0:
        np_axis, labels_name = 0, "index"
    elif axis == 1:
        np_axis, labels_name = 1, "columns"
        arr = to_2d(arr)
    else:
        raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")

    if not checks.is_pandas(arr):
        return np.repeat(arr, n, axis=np_axis)

    new_labels = indexes.repeat_index(getattr(arr, labels_name), n, ignore_ranges=ignore_ranges)
    wrapper = wrapping.ArrayWrapper.from_obj(arr)
    return wrapper.wrap(np.repeat(arr.values, n, axis=np_axis), **{labels_name: new_labels})
def broadcast_shapes(
    *shapes: tp.ArrayLike,
    axis: tp.Optional[tp.MaybeSequence[int]] = None,
    expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Shape, ...]:
    """Broadcast shape-like objects using vectorbt's broadcasting rules.

    If at least one shape is 2-dim, every shape is brought to two dimensions with
    `to_2d_shape`, otherwise to one dimension with `to_1d_shape`. The resulting shapes
    are then combined with `np.broadcast_shapes`.

    Args:
        *shapes: Shape-like objects.
        axis (int or sequence of int): Axis to broadcast; one value for all shapes
            or one value per shape.

            If 0, only the first dimension of the shape is kept (the second one, if any, becomes 1).
            If 1, the first dimension becomes 1 (and the second one, if any, is kept).
        expand_axis (int or sequence of int): Passed to `to_2d_shape`; one value for all
            shapes or one value per shape.

            Defaults to `expand_axis` in `vectorbtpro._settings.broadcasting`.

    Raises:
        ValueError: If an axis is neither None, 0, nor 1.
    """
    from vectorbtpro._settings import settings

    broadcasting_cfg = settings["broadcasting"]
    if expand_axis is None:
        expand_axis = broadcasting_cfg["expand_axis"]

    def _for_position(option: tp.Any, pos: int) -> tp.Any:
        """Pick the value at `pos` if `option` is a sequence, otherwise return `option` itself."""
        return option[pos] if checks.is_sequence(option) else option

    any_2d = any(len(to_tuple_shape(raw_shape)) == 2 for raw_shape in shapes)

    normalized = []
    for pos, raw_shape in enumerate(shapes):
        tuple_shape = to_tuple_shape(raw_shape)
        if any_2d:
            norm_shape = to_2d_shape(tuple_shape, expand_axis=_for_position(expand_axis, pos))
        else:
            norm_shape = to_1d_shape(tuple_shape)

        shape_axis = None if axis is None else _for_position(axis, pos)
        if shape_axis is not None:
            if shape_axis == 0:
                norm_shape = (norm_shape[0], 1) if any_2d else (norm_shape[0],)
            elif shape_axis == 1:
                norm_shape = (1, norm_shape[1]) if any_2d else (1,)
            else:
                raise ValueError(f"Only axes 0 and 1 are supported, not {shape_axis}")
        normalized.append(norm_shape)

    return tuple(np.broadcast_shapes(*normalized))
def broadcast_arrays(
    *arrs: tp.ArrayLike,
    target_shape: tp.Optional[tp.ShapeLike] = None,
    axis: tp.Optional[tp.MaybeSequence[int]] = None,
    expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Array, ...]:
    """Broadcast array-like objects using vectorbt's broadcasting rules.

    If `target_shape` is None, it is derived from the shapes of all arrays with
    `broadcast_shapes`. Each array is then broadcast with `broadcast_array_to`.

    Arguments `axis` and `expand_axis` can be provided either as a single value applied
    to all arrays or as a sequence with one value per array.
    """

    def _for_position(option: tp.Any, pos: int) -> tp.Any:
        """Resolve a global or per-array option for the array at position `pos`."""
        if option is None:
            return None
        return option[pos] if checks.is_sequence(option) else option

    if target_shape is None:
        arr_shapes = [np.asarray(arr).shape for arr in arrs]
        target_shape = broadcast_shapes(*arr_shapes, axis=axis, expand_axis=expand_axis)

    return tuple(
        broadcast_array_to(
            arr,
            target_shape,
            axis=_for_position(axis, pos),
            expand_axis=_for_position(expand_axis, pos),
        )
        for pos, arr in enumerate(arrs)
    )
check_index_names (bool): See `vectorbtpro.utils.checks.is_index_equal`. **clean_index_kwargs: Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`. For defaults, see `vectorbtpro._settings.broadcasting`. !!! note Series names are treated as columns with a single element but without a name. If a column level without a name loses its meaning, better to convert Series to DataFrames with one column prior to broadcasting. If the name of a Series is not that important, better to drop it altogether by setting it to None. """ from vectorbtpro._settings import settings broadcasting_cfg = settings["broadcasting"] if ignore_sr_names is None: ignore_sr_names = broadcasting_cfg["ignore_sr_names"] if check_index_names is None: check_index_names = broadcasting_cfg["check_index_names"] index_str = "columns" if axis == 1 else "index" to_shape_2d = (to_shape[0], 1) if len(to_shape) == 1 else to_shape maxlen = to_shape_2d[1] if axis == 1 else to_shape_2d[0] new_index = None objs = list(objs) if index_from is None or (isinstance(index_from, str) and index_from.lower() == "keep"): return None if isinstance(index_from, int): if not checks.is_pandas(objs[index_from]): raise TypeError(f"Argument under index {index_from} must be a pandas object") new_index = indexes.get_index(objs[index_from], axis) elif isinstance(index_from, str): if index_from.lower() == "reset": new_index = pd.RangeIndex(start=0, stop=maxlen, step=1) elif index_from.lower() in ("stack", "strict"): last_index = None index_conflict = False for obj in objs: if checks.is_pandas(obj): index = indexes.get_index(obj, axis) if last_index is not None: if not checks.is_index_equal(index, last_index, check_names=check_index_names): index_conflict = True last_index = index continue if not index_conflict: new_index = last_index else: for obj in objs: if checks.is_pandas(obj): index = indexes.get_index(obj, axis) if axis == 1 and checks.is_series(obj) and ignore_sr_names: continue if checks.is_default_index(index): 
def wrap_broadcasted(
    new_obj: tp.Array,
    old_obj: tp.Optional[tp.AnyArray] = None,
    axis: tp.Optional[int] = None,
    is_pd: bool = False,
    new_index: tp.Optional[tp.Index] = None,
    new_columns: tp.Optional[tp.Index] = None,
    ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
    """Wrap a newly broadcasted array with Pandas if `is_pd` is True, otherwise return it as is.

    With `axis=0`, the passed `new_columns` are discarded; with `axis=1`, the passed `new_index`
    is discarded. If an index/columns is missing and `old_obj` is a Pandas object, it is taken
    from `old_obj` and repeated with `vectorbtpro.base.indexes.repeat_index` whenever its length
    does not match the new array.

    A 2-dim array becomes a DataFrame. A 1-dim array becomes a Series that is named after
    the only column label if there is exactly one and it is not 0.
    """
    if not is_pd:
        return new_obj

    # Labels of the axis that was not broadcast are not imposed from outside
    if axis == 0:
        new_columns = None
    elif axis == 1:
        new_index = None

    if old_obj is not None and checks.is_pandas(old_obj):
        if new_index is None:
            source_index = indexes.get_index(old_obj, 0)
            if old_obj.shape[0] == new_obj.shape[0]:
                new_index = source_index
            else:
                new_index = indexes.repeat_index(source_index, new_obj.shape[0], ignore_ranges=ignore_ranges)
        if new_columns is None:
            source_columns = indexes.get_index(old_obj, 1)
            n_cols = new_obj.shape[1] if new_obj.ndim == 2 else 1
            if len(source_columns) == n_cols:
                new_columns = source_columns
            else:
                new_columns = indexes.repeat_index(source_columns, n_cols, ignore_ranges=ignore_ranges)

    if new_obj.ndim == 2:
        return pd.DataFrame(new_obj, index=new_index, columns=new_columns)

    sr_name = None
    if new_columns is not None and len(new_columns) == 1:
        only_label = new_columns[0]
        # Label 0 is treated as "no name"
        sr_name = None if only_label == 0 else only_label
    return pd.Series(new_obj, index=new_index, name=sr_name)
new_index = to_index index_changed = True if index_changed: for i in indexes_to_align: if to_index is None or not checks.is_index_equal(objs[i].index, to_index): if objs[i].index.has_duplicates: raise ValueError(f"Index at position {i} contains duplicates") if not objs[i].index.is_monotonic_increasing: raise ValueError(f"Index at position {i} is not monotonically increasing") _reindex_kwargs = resolve_dict(reindex_kwargs, i=i) was_bool = (isinstance(objs[i], pd.Series) and objs[i].dtype == "bool") or ( isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "bool").all() ) objs[i] = objs[i].reindex(new_index, **_reindex_kwargs) is_object = (isinstance(objs[i], pd.Series) and objs[i].dtype == "object") or ( isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "object").all() ) if was_bool and is_object: objs[i] = objs[i].astype(None) if align_columns: columns_to_align = [] for i in range(len(objs)): if axis is not None: if checks.is_sequence(axis): _axis = axis[i] else: _axis = axis else: _axis = None if _axis in (None, 1): if checks.is_frame(objs[i]) and len(objs[i].columns) > 1: if not checks.is_default_index(objs[i].columns): columns_to_align.append(i) if (len(columns_to_align) > 0 and to_columns is not None) or len(columns_to_align) > 1: indexes_ = [objs[i].columns for i in columns_to_align] if to_columns is not None: indexes_.append(to_columns) if len(set(map(len, indexes_))) > 1: col_indices = indexes.align_indexes(*indexes_) for i in columns_to_align: objs[i] = objs[i].iloc[:, col_indices[columns_to_align.index(i)]] if len(objs) == 1: return objs[0] return tuple(objs) @define class BCO(DefineMixin): """Class that represents an object passed to `broadcast`. If any value is None, mostly defaults to the global value passed to `broadcast`.""" value: tp.Any = define.field() """Value of the object.""" axis: tp.Optional[int] = define.field(default=None) """Axis to broadcast. 
Set to None to broadcast all axes.""" to_pd: tp.Optional[bool] = define.field(default=None) """Whether to convert the output array to a Pandas object.""" keep_flex: tp.Optional[bool] = define.field(default=None) """Whether to keep the raw version of the output for flexible indexing. Only makes sure that the array can broadcast to the target shape.""" min_ndim: tp.Optional[int] = define.field(default=None) """Minimum number of dimensions.""" expand_axis: tp.Optional[int] = define.field(default=None) """Axis to expand if the array is 1-dim but the target shape is 2-dim.""" post_func: tp.Optional[tp.Callable] = define.field(default=None) """Function to post-process the output array.""" require_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) """Keyword arguments passed to `np.require`.""" reindex_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) """Keyword arguments passed to `pd.DataFrame.reindex`.""" merge_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None) """Keyword arguments passed to `vectorbtpro.base.merging.column_stack_merge`.""" context: tp.KwargsLike = define.field(default=None) """Context used in evaluation of templates. 
def resolve_ref(dct: dict, k: tp.Hashable, inside_bco: bool = False, keep_wrap_default: bool = False) -> tp.Any:
    """Resolve the value under the key `k` in `dct` if it is a reference to another key.

    A `Default` wrapper around the value is stripped before checking for `Ref`.
    If `inside_bco` is True and the value is a `BCO` instance, the same is done
    for the value stored inside that instance.

    References are followed recursively. If the reference was wrapped with `Default`
    and `keep_wrap_default` is True, the resolved value is wrapped with `Default` again.
    Values that are not references are returned unwrapped.
    """

    def _strip_default(value: tp.Any) -> tp.Tuple[tp.Any, bool]:
        """Remove a `Default` wrapper, if any, and tell whether there was one."""
        if isinstance(value, Default):
            return value.value, True
        return value, False

    def _follow(ref: Ref, was_default: bool) -> tp.Any:
        """Resolve the key the reference points to and restore `Default` if requested."""
        target = resolve_ref(dct, ref.key, inside_bco=inside_bco)
        if keep_wrap_default and was_default:
            return Default(target)
        return target

    v, was_default = _strip_default(dct[k])
    if isinstance(v, Ref):
        return _follow(v, was_default)

    if inside_bco and isinstance(v, BCO):
        # Look one level deeper: the reference may be stored as the value of the BCO instance
        v, was_default = _strip_default(v.value)
        if isinstance(v, Ref):
            return _follow(v, was_default)
    return v
tp.Optional[bool] = None, return_wrapper: bool = False, wrapper_kwargs: tp.KwargsLike = None, ignore_sr_names: tp.Optional[bool] = None, ignore_ranges: tp.Optional[bool] = None, check_index_names: tp.Optional[bool] = None, clean_index_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, ) -> tp.Any: """Bring any array-like object in `objs` to the same shape by using NumPy-like broadcasting. See [Broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). !!! important The major difference to NumPy is that one-dimensional arrays will always broadcast against the row axis! Can broadcast Pandas objects by broadcasting their index/columns with `broadcast_index`. Args: *objs: Objects to broadcast. If the first and only argument is a mapping, will return a dict. Allows using `BCO`, `Ref`, `Default`, `vectorbtpro.utils.params.Param`, `vectorbtpro.base.indexing.index_dict`, `vectorbtpro.base.indexing.IdxSetter`, `vectorbtpro.base.indexing.IdxSetterFactory`, and templates. If an index dictionary, fills using `vectorbtpro.base.wrapping.ArrayWrapper.fill_and_set`. to_shape (tuple of int): Target shape. If set, will broadcast every object in `objs` to `to_shape`. align_index (bool): Whether to align index of Pandas objects using union. Pass None to use the default. align_columns (bool): Whether to align columns of Pandas objects using multi-index. Pass None to use the default. index_from (any): Broadcasting rule for index. Pass None to use the default. columns_from (any): Broadcasting rule for columns. Pass None to use the default. to_frame (bool): Whether to convert all Series to DataFrames. axis (int, sequence or mapping): See `BCO.axis`. to_pd (bool, sequence or mapping): See `BCO.to_pd`. If None, converts only if there is at least one Pandas object among them. keep_flex (bool, sequence or mapping): See `BCO.keep_flex`. min_ndim (int, sequence or mapping): See `BCO.min_ndim`. If None, becomes 2 if `keep_flex` is True, otherwise 1. 
expand_axis (int, sequence or mapping): See `BCO.expand_axis`. post_func (callable, sequence or mapping): See `BCO.post_func`. Applied only when `keep_flex` is False. require_kwargs (dict, sequence or mapping): See `BCO.require_kwargs`. This key will be merged with any argument-specific dict. If the mapping contains all keys in `np.require`, it will be applied to all objects. reindex_kwargs (dict, sequence or mapping): See `BCO.reindex_kwargs`. This key will be merged with any argument-specific dict. If the mapping contains all keys in `pd.DataFrame.reindex`, it will be applied to all objects. merge_kwargs (dict, sequence or mapping): See `BCO.merge_kwargs`. This key will be merged with any argument-specific dict. If the mapping contains all keys in `pd.DataFrame.merge`, it will be applied to all objects. tile (int or index_like): Tile the final object by the number of times or index. random_subset (int): Select a random subset of parameter values. Seed can be set using NumPy before calling this function. seed (int): Seed to make output deterministic. keep_wrap_default (bool): Whether to keep wrapping with `vectorbtpro.base.reshaping.Default`. return_wrapper (bool): Whether to also return the wrapper associated with the operation. wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`. ignore_sr_names (bool): See `broadcast_index`. ignore_ranges (bool): See `broadcast_index`. check_index_names (bool): See `broadcast_index`. clean_index_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`. template_context (dict): Context used to substitute templates. For defaults, see `vectorbtpro._settings.broadcasting`. Any keyword argument that can be associated with an object can be passed as * a const that is applied to all objects, * a sequence with value per object, and * a mapping with value per object name and the special key `_def` denoting the default value. 
Additionally, any object can be passed wrapped with `BCO`, whose attributes will override any of the above arguments if not None. Usage: * Without broadcasting index and columns: ```pycon >>> from vectorbtpro import * >>> v = 0 >>> a = np.array([1, 2, 3]) >>> sr = pd.Series([1, 2, 3], index=pd.Index(['x', 'y', 'z']), name='a') >>> df = pd.DataFrame( ... [[1, 2, 3], [4, 5, 6], [7, 8, 9]], ... index=pd.Index(['x2', 'y2', 'z2']), ... columns=pd.Index(['a2', 'b2', 'c2']), ... ) >>> for i in vbt.broadcast( ... v, a, sr, df, ... index_from='keep', ... columns_from='keep', ... align_index=False ... ): print(i) 0 1 2 0 0 0 0 1 0 0 0 2 0 0 0 0 1 2 0 1 2 3 1 1 2 3 2 1 2 3 a a a x 1 1 1 y 2 2 2 z 3 3 3 a2 b2 c2 x2 1 2 3 y2 4 5 6 z2 7 8 9 ``` * Take index and columns from the argument at specific position: ```pycon >>> for i in vbt.broadcast( ... v, a, sr, df, ... index_from=2, ... columns_from=3, ... align_index=False ... ): print(i) a2 b2 c2 x 0 0 0 y 0 0 0 z 0 0 0 a2 b2 c2 x 1 2 3 y 1 2 3 z 1 2 3 a2 b2 c2 x 1 1 1 y 2 2 2 z 3 3 3 a2 b2 c2 x 1 2 3 y 4 5 6 z 7 8 9 ``` * Broadcast index and columns through stacking: ```pycon >>> for i in vbt.broadcast( ... v, a, sr, df, ... index_from='stack', ... columns_from='stack', ... align_index=False ... ): print(i) a2 b2 c2 x x2 0 0 0 y y2 0 0 0 z z2 0 0 0 a2 b2 c2 x x2 1 2 3 y y2 1 2 3 z z2 1 2 3 a2 b2 c2 x x2 1 1 1 y y2 2 2 2 z z2 3 3 3 a2 b2 c2 x x2 1 2 3 y y2 4 5 6 z z2 7 8 9 ``` * Set index and columns manually: ```pycon >>> for i in vbt.broadcast( ... v, a, sr, df, ... index_from=['a', 'b', 'c'], ... columns_from=['d', 'e', 'f'], ... align_index=False ... ): print(i) d e f a 0 0 0 b 0 0 0 c 0 0 0 d e f a 1 2 3 b 1 2 3 c 1 2 3 d e f a 1 1 1 b 2 2 2 c 3 3 3 d e f a 1 2 3 b 4 5 6 c 7 8 9 ``` * Passing arguments as a mapping returns a mapping: ```pycon >>> vbt.broadcast( ... dict(v=v, a=a, sr=sr, df=df), ... index_from='stack', ... align_index=False ...
) {'v': a2 b2 c2 x x2 0 0 0 y y2 0 0 0 z z2 0 0 0, 'a': a2 b2 c2 x x2 1 2 3 y y2 1 2 3 z z2 1 2 3, 'sr': a2 b2 c2 x x2 1 1 1 y y2 2 2 2 z z2 3 3 3, 'df': a2 b2 c2 x x2 1 2 3 y y2 4 5 6 z z2 7 8 9} ``` * Keep all results in a format suitable for flexible indexing apart from one: ```pycon >>> vbt.broadcast( ... dict(v=v, a=a, sr=sr, df=df), ... index_from='stack', ... keep_flex=dict(_def=True, df=False), ... require_kwargs=dict(df=dict(dtype=float)), ... align_index=False ... ) {'v': array([0]), 'a': array([1, 2, 3]), 'sr': array([[1], [2], [3]]), 'df': a2 b2 c2 x x2 1.0 2.0 3.0 y y2 4.0 5.0 6.0 z z2 7.0 8.0 9.0} ``` * Specify arguments per object using `BCO`: ```pycon >>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float)) >>> vbt.broadcast( ... dict(v=v, a=a, sr=sr, df=df_bco), ... index_from='stack', ... keep_flex=True, ... align_index=False ... ) {'v': array([0]), 'a': array([1, 2, 3]), 'sr': array([[1], [2], [3]]), 'df': a2 b2 c2 x x2 1.0 2.0 3.0 y y2 4.0 5.0 6.0 z z2 7.0 8.0 9.0} ``` * Introduce a parameter that should build a Cartesian product of its values and other objects: ```pycon >>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float)) >>> p_bco = vbt.BCO(pd.Param([1, 2, 3], name='my_p')) >>> vbt.broadcast( ... dict(v=v, a=a, sr=sr, df=df_bco, p=p_bco), ... index_from='stack', ... keep_flex=True, ... align_index=False ... ) {'v': array([0]), 'a': array([1, 2, 3, 1, 2, 3, 1, 2, 3]), 'sr': array([[1], [2], [3]]), 'df': my_p 1 2 3 a2 b2 c2 a2 b2 c2 a2 b2 c2 x x2 1.0 2.0 3.0 1.0 2.0 3.0 1.0 2.0 3.0 y y2 4.0 5.0 6.0 4.0 5.0 6.0 4.0 5.0 6.0 z z2 7.0 8.0 9.0 7.0 8.0 9.0 7.0 8.0 9.0, 'p': array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 1, 1, 2, 2, 2, 3, 3, 3]])} ``` * Build a Cartesian product of all parameters: ```pycon >>> vbt.broadcast( ... dict( ... a=vbt.Param([1, 2, 3]), ... b=vbt.Param(['x', 'y']), ... c=vbt.Param([False, True]) ... ) ... 
) {'a': array([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]]), 'b': array([['x', 'x', 'y', 'y', 'x', 'x', 'y', 'y', 'x', 'x', 'y', 'y']], dtype='>> vbt.broadcast( ... dict( ... a=vbt.Param([1, 2, 3], level=0), ... b=vbt.Param(['x', 'y'], level=1), ... d=vbt.Param([100., 200., 300.], level=0), ... c=vbt.Param([False, True], level=1) ... ) ... ) {'a': array([[1, 1, 2, 2, 3, 3]]), 'b': array([['x', 'y', 'x', 'y', 'x', 'y']], dtype='>> vbt.broadcast( ... dict( ... a=vbt.Param([1, 2, 3]), ... b=vbt.Param(['x', 'y']), ... c=vbt.Param([False, True]) ... ), ... random_subset=5, ... seed=42 ... ) {'a': array([[1, 2, 3, 3, 3]]), 'b': array([['x', 'x', 'x', 'x', 'y']], dtype=' 1: raise ValueError("Only one argument is allowed when passing a mapping") all_keys = list(dict(objs[0]).keys()) objs = list(objs[0].values()) return_dict = True else: objs = list(objs) all_keys = list(range(len(objs))) return_dict = False def _resolve_arg(obj: tp.Any, arg_name: str, global_value: tp.Any, default_value: tp.Any) -> tp.Any: if isinstance(obj, BCO) and getattr(obj, arg_name) is not None: return getattr(obj, arg_name) if checks.is_mapping(global_value): return global_value.get(k, global_value.get("_def", default_value)) if checks.is_sequence(global_value): return global_value[i] return global_value # Build BCO instances none_keys = set() default_keys = set() param_keys = set() special_keys = set() bco_instances = {} pool = dict(zip(all_keys, objs)) for i, k in enumerate(all_keys): obj = objs[i] if isinstance(obj, Default): obj = obj.value default_keys.add(k) if isinstance(obj, Ref): obj = resolve_ref(pool, k) if isinstance(obj, BCO): value = obj.value else: value = obj if isinstance(value, Default): value = value.value default_keys.add(k) if isinstance(value, Ref): value = resolve_ref(pool, k, inside_bco=True) if value is None: none_keys.add(k) continue _axis = _resolve_arg(obj, "axis", axis, None) _to_pd = _resolve_arg(obj, "to_pd", to_pd, None) _keep_flex = _resolve_arg(obj, "keep_flex", 
keep_flex, None) if _keep_flex is None: _keep_flex = broadcasting_cfg["keep_flex"] _min_ndim = _resolve_arg(obj, "min_ndim", min_ndim, None) if _min_ndim is None: _min_ndim = broadcasting_cfg["min_ndim"] _expand_axis = _resolve_arg(obj, "expand_axis", expand_axis, None) if _expand_axis is None: _expand_axis = broadcasting_cfg["expand_axis"] _post_func = _resolve_arg(obj, "post_func", post_func, None) if isinstance(obj, BCO) and obj.require_kwargs is not None: _require_kwargs = obj.require_kwargs else: _require_kwargs = None if checks.is_mapping(require_kwargs) and require_kwargs_per_obj: _require_kwargs = merge_dicts( require_kwargs.get("_def", None), require_kwargs.get(k, None), _require_kwargs, ) elif checks.is_sequence(require_kwargs) and require_kwargs_per_obj: _require_kwargs = merge_dicts(require_kwargs[i], _require_kwargs) else: _require_kwargs = merge_dicts(require_kwargs, _require_kwargs) if isinstance(obj, BCO) and obj.reindex_kwargs is not None: _reindex_kwargs = obj.reindex_kwargs else: _reindex_kwargs = None if checks.is_mapping(reindex_kwargs) and reindex_kwargs_per_obj: _reindex_kwargs = merge_dicts( reindex_kwargs.get("_def", None), reindex_kwargs.get(k, None), _reindex_kwargs, ) elif checks.is_sequence(reindex_kwargs) and reindex_kwargs_per_obj: _reindex_kwargs = merge_dicts(reindex_kwargs[i], _reindex_kwargs) else: _reindex_kwargs = merge_dicts(reindex_kwargs, _reindex_kwargs) if isinstance(obj, BCO) and obj.merge_kwargs is not None: _merge_kwargs = obj.merge_kwargs else: _merge_kwargs = None if checks.is_mapping(merge_kwargs) and merge_kwargs_per_obj: _merge_kwargs = merge_dicts( merge_kwargs.get("_def", None), merge_kwargs.get(k, None), _merge_kwargs, ) elif checks.is_sequence(merge_kwargs) and merge_kwargs_per_obj: _merge_kwargs = merge_dicts(merge_kwargs[i], _merge_kwargs) else: _merge_kwargs = merge_dicts(merge_kwargs, _merge_kwargs) if isinstance(obj, BCO): _context = merge_dicts(template_context, obj.context) else: _context = 
template_context if isinstance(value, Param): param_keys.add(k) elif isinstance(value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory, CustomTemplate)): special_keys.add(k) else: value = to_any_array(value) bco_instances[k] = BCO( value, axis=_axis, to_pd=_to_pd, keep_flex=_keep_flex, min_ndim=_min_ndim, expand_axis=_expand_axis, post_func=_post_func, require_kwargs=_require_kwargs, reindex_kwargs=_reindex_kwargs, merge_kwargs=_merge_kwargs, context=_context, ) # Check whether we should broadcast Pandas metadata and work on 2-dim data is_pd = False is_2d = False old_objs = {} obj_axis = {} obj_reindex_kwargs = {} for k, bco_obj in bco_instances.items(): if k in none_keys or k in param_keys or k in special_keys: continue obj = bco_obj.value if obj.ndim > 1: is_2d = True if checks.is_pandas(obj): is_pd = True if bco_obj.to_pd is not None and bco_obj.to_pd: is_pd = True old_objs[k] = obj obj_axis[k] = bco_obj.axis obj_reindex_kwargs[k] = bco_obj.reindex_kwargs if to_shape is not None: if isinstance(to_shape, int): to_shape = (to_shape,) if len(to_shape) > 1: is_2d = True if to_frame is not None: is_2d = to_frame if to_pd is not None: is_pd = to_pd or (return_wrapper and is_pd) # Align pandas arrays if index_from is not None and not isinstance(index_from, (int, str, pd.Index)): index_from = pd.Index(index_from) if columns_from is not None and not isinstance(columns_from, (int, str, pd.Index)): columns_from = pd.Index(columns_from) aligned_objs = align_pd_arrays( *old_objs.values(), align_index=align_index, align_columns=align_columns, to_index=index_from if isinstance(index_from, pd.Index) else None, to_columns=columns_from if isinstance(columns_from, pd.Index) else None, axis=list(obj_axis.values()), reindex_kwargs=list(obj_reindex_kwargs.values()), ) if not isinstance(aligned_objs, tuple): aligned_objs = (aligned_objs,) aligned_objs = dict(zip(old_objs.keys(), aligned_objs)) # Convert to NumPy ready_objs = {} for k, obj in aligned_objs.items(): 
_expand_axis = bco_instances[k].expand_axis new_obj = np.asarray(obj) if is_2d and new_obj.ndim == 1: if isinstance(obj, pd.Series): new_obj = new_obj[:, None] else: new_obj = np.expand_dims(new_obj, _expand_axis) ready_objs[k] = new_obj # Get final shape if to_shape is None: try: to_shape = broadcast_shapes( *map(lambda x: x.shape, ready_objs.values()), axis=list(obj_axis.values()), ) except ValueError: arr_shapes = {} for i, k in enumerate(bco_instances): if k in none_keys or k in param_keys or k in special_keys: continue if len(ready_objs[k].shape) > 0: arr_shapes[k] = ready_objs[k].shape raise ValueError("Could not broadcast shapes: %s" % str(arr_shapes)) if not isinstance(to_shape, tuple): to_shape = (to_shape,) if len(to_shape) == 0: to_shape = (1,) to_shape_2d = to_shape if len(to_shape) > 1 else (*to_shape, 1) if is_pd: # Decide on index and columns # NOTE: Important to pass aligned_objs, not ready_objs, to preserve original shape info new_index = broadcast_index( [v for k, v in aligned_objs.items() if obj_axis[k] in (None, 0)], to_shape, index_from=index_from, axis=0, ignore_sr_names=ignore_sr_names, ignore_ranges=ignore_ranges, check_index_names=check_index_names, **clean_index_kwargs, ) new_columns = broadcast_index( [v for k, v in aligned_objs.items() if obj_axis[k] in (None, 1)], to_shape, index_from=columns_from, axis=1, ignore_sr_names=ignore_sr_names, ignore_ranges=ignore_ranges, check_index_names=check_index_names, **clean_index_kwargs, ) else: new_index = pd.RangeIndex(stop=to_shape_2d[0]) new_columns = pd.RangeIndex(stop=to_shape_2d[1]) # Build a product param_product = None param_columns = None n_params = 0 if len(param_keys) > 0: # Combine parameters param_dct = {} for k, bco_obj in bco_instances.items(): if k not in param_keys: continue param_dct[k] = bco_obj.value param_product, param_columns = combine_params( param_dct, random_subset=random_subset, seed=seed, clean_index_kwargs=clean_index_kwargs, ) n_params = len(param_columns) # Combine 
parameter columns with new columns if param_columns is not None and new_columns is not None: new_columns = indexes.combine_indexes([param_columns, new_columns], **clean_index_kwargs) # Tile if tile is not None: if isinstance(tile, int): if new_columns is not None: new_columns = indexes.tile_index(new_columns, tile) else: if new_columns is not None: new_columns = indexes.combine_indexes([tile, new_columns], **clean_index_kwargs) tile = len(tile) n_params = max(n_params, 1) * tile # Build wrapper if n_params == 0: new_shape = to_shape else: new_shape = (to_shape_2d[0], to_shape_2d[1] * n_params) wrapper = wrapping.ArrayWrapper.from_shape( new_shape, **merge_dicts( dict( index=new_index, columns=new_columns, ), wrapper_kwargs, ), ) def _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis): if _min_ndim is None: if _keep_flex: _min_ndim = 2 else: _min_ndim = 1 if _min_ndim not in (1, 2): raise ValueError("Argument min_ndim must be either 1 or 2") if _min_ndim in (1, 2) and new_obj.ndim == 0: new_obj = new_obj[None] if _min_ndim == 2 and new_obj.ndim == 1: if len(to_shape) == 1: new_obj = new_obj[:, None] else: new_obj = np.expand_dims(new_obj, _expand_axis) return new_obj # Perform broadcasting aligned_objs2 = {} new_objs = {} for i, k in enumerate(all_keys): if k in none_keys or k in special_keys: continue _keep_flex = bco_instances[k].keep_flex _min_ndim = bco_instances[k].min_ndim _axis = bco_instances[k].axis _expand_axis = bco_instances[k].expand_axis _merge_kwargs = bco_instances[k].merge_kwargs _context = bco_instances[k].context must_reset_index = _merge_kwargs.get("reset_index", None) not in (None, False) _reindex_kwargs = resolve_dict(bco_instances[k].reindex_kwargs) _fill_value = _reindex_kwargs.get("fill_value", np.nan) if k in param_keys: # Broadcast parameters from vectorbtpro.base.merging import column_stack_merge if _axis == 0: raise ValueError("Parameters do not support broadcasting with axis=0") obj = param_product[k] new_obj = [] 
any_needs_broadcasting = False all_forced_broadcast = True for o in obj: if isinstance(o, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)): o = wrapper.fill_and_set( o, fill_value=_fill_value, keep_flex=_keep_flex, ) elif isinstance(o, CustomTemplate): context = merge_dicts( dict( bco_instances=bco_instances, wrapper=wrapper, obj_name=k, bco=bco_instances[k], ), _context, ) o = o.substitute(context, eval_id="broadcast") o = to_2d_array(o) if not _keep_flex: needs_broadcasting = True elif o.shape[0] > 1: needs_broadcasting = True elif o.shape[1] > 1 and o.shape[1] != to_shape_2d[1]: needs_broadcasting = True else: needs_broadcasting = False if needs_broadcasting: any_needs_broadcasting = True o = broadcast_array_to(o, to_shape_2d, axis=_axis) elif o.size == 1: all_forced_broadcast = False o = np.repeat(o, to_shape_2d[1], axis=1) else: all_forced_broadcast = False new_obj.append(o) if any_needs_broadcasting and not all_forced_broadcast: new_obj2 = [] for o in new_obj: if o.shape[1] != to_shape_2d[1] or (not must_reset_index and o.shape[0] != to_shape_2d[0]): o = broadcast_array_to(o, to_shape_2d, axis=_axis) new_obj2.append(o) new_obj = new_obj2 obj = column_stack_merge(new_obj, **_merge_kwargs) if tile is not None: obj = np.tile(obj, (1, tile)) old_obj = obj new_obj = obj else: # Broadcast regular objects old_obj = aligned_objs[k] new_obj = ready_objs[k] if _axis in (None, 0) and new_obj.ndim >= 1 and new_obj.shape[0] > 1 and new_obj.shape[0] != to_shape[0]: raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}") if _axis in (None, 1) and new_obj.ndim == 2 and new_obj.shape[1] > 1 and new_obj.shape[1] != to_shape[1]: raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}") if _keep_flex: if n_params > 0 and _axis in (None, 1): if len(to_shape) == 1: if new_obj.ndim == 1 and new_obj.shape[0] > 1: new_obj = new_obj[:, None] # product changes is_2d behavior else: if 
new_obj.ndim == 1 and new_obj.shape[0] > 1: new_obj = np.tile(new_obj, n_params) elif new_obj.ndim == 2 and new_obj.shape[1] > 1: new_obj = np.tile(new_obj, (1, n_params)) else: new_obj = broadcast_array_to(new_obj, to_shape, axis=_axis) if n_params > 0 and _axis in (None, 1): if new_obj.ndim == 1: new_obj = new_obj[:, None] # product changes is_2d behavior new_obj = np.tile(new_obj, (1, n_params)) new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis) aligned_objs2[k] = old_obj new_objs[k] = new_obj # Resolve special objects new_objs2 = {} for i, k in enumerate(all_keys): if k in none_keys: continue if k in special_keys: bco = bco_instances[k] if isinstance(bco.value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)): _is_pd = bco.to_pd if _is_pd is None: _is_pd = is_pd _keep_flex = bco.keep_flex _min_ndim = bco.min_ndim _expand_axis = bco.expand_axis _reindex_kwargs = resolve_dict(bco.reindex_kwargs) _fill_value = _reindex_kwargs.get("fill_value", np.nan) new_obj = wrapper.fill_and_set( bco.value, fill_value=_fill_value, keep_flex=_keep_flex, ) if not _is_pd and not _keep_flex: new_obj = new_obj.values new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis) elif isinstance(bco.value, CustomTemplate): context = merge_dicts( dict( bco_instances=bco_instances, new_objs=new_objs, wrapper=wrapper, obj_name=k, bco=bco, ), bco.context, ) new_obj = bco.value.substitute(context, eval_id="broadcast") else: raise TypeError(f"Special type {type(bco.value)} is not supported") else: new_obj = new_objs[k] # Force to match requirements new_obj = np.require(new_obj, **resolve_dict(bco_instances[k].require_kwargs)) new_objs2[k] = new_obj # Perform wrapping and post-processing new_objs3 = {} for i, k in enumerate(all_keys): if k in none_keys: continue new_obj = new_objs2[k] _axis = bco_instances[k].axis _keep_flex = bco_instances[k].keep_flex if not _keep_flex: # Wrap array _is_pd = bco_instances[k].to_pd if _is_pd is None: _is_pd = 
is_pd new_obj = wrap_broadcasted( new_obj, old_obj=aligned_objs2[k] if k not in special_keys else None, axis=_axis, is_pd=_is_pd, new_index=new_index, new_columns=new_columns, ignore_ranges=ignore_ranges, ) # Post-process array _post_func = bco_instances[k].post_func if _post_func is not None: new_obj = _post_func(new_obj) new_objs3[k] = new_obj # Prepare outputs return_objs = [] for k in all_keys: if k not in none_keys: if k in default_keys and keep_wrap_default: return_objs.append(Default(new_objs3[k])) else: return_objs.append(new_objs3[k]) else: if k in default_keys and keep_wrap_default: return_objs.append(Default(None)) else: return_objs.append(None) if return_dict: return_objs = dict(zip(all_keys, return_objs)) else: return_objs = tuple(return_objs) if len(return_objs) > 1 or return_dict: if return_wrapper: return return_objs, wrapper return return_objs if return_wrapper: return return_objs[0], wrapper return return_objs[0] def broadcast_to( arg1: tp.ArrayLike, arg2: tp.Union[tp.ArrayLike, tp.ShapeLike, wrapping.ArrayWrapper], to_pd: tp.Optional[bool] = None, index_from: tp.Optional[IndexFromLike] = None, columns_from: tp.Optional[IndexFromLike] = None, **kwargs, ) -> tp.Any: """Broadcast `arg1` to `arg2`. Argument `arg2` can be a shape, an instance of `vectorbtpro.base.wrapping.ArrayWrapper`, or any array-like object. Pass None to `index_from`/`columns_from` to use index/columns of the second argument. Keyword arguments `**kwargs` are passed to `broadcast`. 
def broadcast_to_array_of(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> tp.Array:
    """Repeat each element of `arg1` into an array shaped like `arg2`.

    Both arguments are converted with `np.asarray`. If `arg1` already has one dimension
    more than `arg2` and its trailing dimensions equal the shape of `arg2`, it is
    returned as is. Otherwise, `arg1` must be a scalar (converted with `to_1d`) or
    a 1-dim array, and the result has the shape `(len(arg1), *arg2.shape)`.
    If `arg2` is zero-dimensional, the 1-dim `arg1` is returned without tiling.

    Usage:
        ```pycon
        >>> from vectorbtpro import *
        >>> from vectorbtpro.base.reshaping import broadcast_to_array_of

        >>> broadcast_to_array_of([0.1, 0.2], np.empty((2, 2)))
        [[[0.1 0.1]
          [0.1 0.1]]

         [[0.2 0.2]
          [0.2 0.2]]]
        ```
    """
    source = np.asarray(arg1)
    target = np.asarray(arg2)

    already_expanded = source.ndim == target.ndim + 1 and source.shape[1:] == target.shape
    if already_expanded:
        return source

    # From here on, only a scalar or a 1-dim array is acceptable
    if source.ndim == 0:
        source = to_1d(source)
    checks.assert_ndim(source, 1)
    if target.ndim == 0:
        return source

    # Append one singleton axis per target dimension, then repeat along those axes
    trailing_axes = tuple(range(1, target.ndim + 1))
    expanded = np.expand_dims(source, axis=trailing_axes)
    return np.tile(expanded, (1, *target.shape))
def broadcast_combs(
    *objs: tp.ArrayLike,
    axis: int = 1,
    comb_func: tp.Callable = itertools.product,
    **broadcast_kwargs,
) -> tp.Any:
    """Align an axis of each array using a combinatoric function and broadcast their indexes.

    Args:
        *objs (array_like): Two or more objects, passed as separate positional arguments.

            Each one is converted with `to_any_array` and, if `axis` is 1, with `to_2d`.
        axis (int): Axis whose positions are combined.
        comb_func (callable): Function that receives one array of positions per object
            and yields tuples holding one position per object, such as `itertools.product`.
        **broadcast_kwargs: Keyword arguments passed to `broadcast`.

            By default, `columns_from="stack"` is used if `axis` is 1 and
            `index_from="stack"` otherwise; explicitly passed keys are merged over these defaults.

    Raises:
        ValueError: If fewer than two objects are passed.

    Usage:
        ```pycon
        >>> from vectorbtpro import *
        >>> from vectorbtpro.base.reshaping import broadcast_combs

        >>> df = pd.DataFrame([[1, 2, 3], [3, 4, 5]], columns=pd.Index(['a', 'b', 'c'], name='df_param'))
        >>> df2 = pd.DataFrame([[6, 7], [8, 9]], columns=pd.Index(['d', 'e'], name='df2_param'))
        >>> sr = pd.Series([10, 11], name='f')

        >>> new_df, new_df2, new_sr = broadcast_combs(df, df2, sr)

        >>> new_df
        df_param   a     b     c
        df2_param  d  e  d  e  d  e
        0          1  1  2  2  3  3
        1          3  3  4  4  5  5

        >>> new_df2
        df_param   a     b     c
        df2_param  d  e  d  e  d  e
        0          6  7  6  7  6  7
        1          8  9  8  9  8  9

        >>> new_sr
        df_param    a       b       c
        df2_param   d   e   d   e   d   e
        0          10  10  10  10  10  10
        1          11  11  11  11  11  11
        ```
    """
    # NOTE: `**broadcast_kwargs` is always a dict, so no None check is needed
    if len(objs) < 2:
        raise ValueError("At least two arguments are required")

    arrays = []
    for obj in objs:
        obj = to_any_array(obj)
        if axis == 1:
            obj = to_2d(obj)
        arrays.append(obj)

    # One array of integer positions per object along the requested axis
    positions = [np.arange(len(indexes.get_index(to_pd_array(obj), axis))) for obj in arrays]

    # Transpose the combinations so that there is one list of positions per object
    new_indices = list(map(list, zip(*comb_func(*positions))))

    results = []
    for i, obj in enumerate(arrays):
        # Pandas objects are selected by position through `iloc`, NumPy arrays directly
        selector = obj.iloc if checks.is_pandas(obj) else obj
        if axis == 1:
            results.append(selector[:, new_indices[i]])
        else:
            results.append(selector[new_indices[i]])

    if axis == 1:
        broadcast_kwargs = merge_dicts(dict(columns_from="stack"), broadcast_kwargs)
    else:
        broadcast_kwargs = merge_dicts(dict(index_from="stack"), broadcast_kwargs)
    return broadcast(*results, **broadcast_kwargs)
def make_symmetric(obj: tp.SeriesFrame, sort: bool = True) -> tp.Frame:
    """Build a DataFrame whose index and columns are identical from `obj`.

    The new axis holds the labels of both the index and the columns of `obj`.
    If either axis is a multi-index, both must be multi-indexes with the same number of levels.

    With `sort=True`, the combined labels are passed through `np.unique`. With `sort=False`,
    they are concatenated in their original order and duplicates are dropped.

    Usage:
        ```pycon
        >>> from vectorbtpro import *
        >>> from vectorbtpro.base.reshaping import make_symmetric

        >>> df = pd.DataFrame([[1, 2], [3, 4]], index=['a', 'b'], columns=['c', 'd'])

        >>> make_symmetric(df)
             a    b    c    d
        a  NaN  NaN  1.0  2.0
        b  NaN  NaN  3.0  4.0
        c  1.0  3.0  NaN  NaN
        d  2.0  4.0  NaN  NaN
        ```
    """
    from vectorbtpro.base.merging import concat_arrays

    checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
    frame = to_2d(obj)
    rows, cols = frame.index, frame.columns

    # Resolve the name(s) of each axis
    if isinstance(rows, pd.MultiIndex) or isinstance(cols, pd.MultiIndex):
        checks.assert_instance_of(rows, pd.MultiIndex)
        checks.assert_instance_of(cols, pd.MultiIndex)
        checks.assert_array_equal(rows.nlevels, cols.nlevels)
        row_names, col_names = tuple(rows.names), tuple(cols.names)
    else:
        row_names, col_names = rows.name, cols.name
    is_multi = isinstance(rows, pd.MultiIndex)

    # Use the shared name, or pair up the differing names
    if row_names == col_names:
        shared_name = row_names
    elif is_multi:
        shared_name = tuple(zip(row_names, col_names))
    else:
        shared_name = (row_names, col_names)

    # Collect the labels of the symmetric axis
    all_labels = concat_arrays((rows, cols))
    if sort:
        labels = np.unique(all_labels).tolist()
    else:
        labels = list(dict.fromkeys(all_labels))

    # Rename copies of both axes so the input object stays untouched
    new_rows = rows.copy()
    new_cols = cols.copy()
    if is_multi:
        sym_index = pd.MultiIndex.from_tuples(labels, names=shared_name)
        new_rows.names = shared_name
        new_cols.names = shared_name
    else:
        sym_index = pd.Index(labels, name=shared_name)
        new_rows.name = shared_name
        new_cols.name = shared_name
    frame = frame.copy(deep=False)
    frame.index = new_rows
    frame.columns = new_cols

    # The output dtype must be able to hold NaN
    out_dtype = np.promote_types(frame.values.dtype, np.min_scalar_type(np.nan))
    result = pd.DataFrame(index=sym_index, columns=sym_index, dtype=out_dtype)
    result.loc[:, :] = frame
    # Fill the cells that are still missing from the transposed data
    result[result.isnull()] = frame.transpose()
    return result
DataFrame. Use `index_levels` to specify what index levels will form new index, and `column_levels` for new columns. Set `symmetric` to True to make DataFrame symmetric. Usage: ```pycon >>> from vectorbtpro import * >>> from vectorbtpro.base.reshaping import unstack_to_df >>> index = pd.MultiIndex.from_arrays( ... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']], ... names=['x', 'y', 'z']) >>> sr = pd.Series([1, 2, 3, 4], index=index) >>> unstack_to_df(sr, index_levels=(0, 1), column_levels=2) z a b c d x y 1 3 1.0 NaN NaN NaN 1 4 NaN 2.0 NaN NaN 2 3 NaN NaN 3.0 NaN 2 4 NaN NaN NaN 4.0 ``` """ sr = get_multiindex_series(obj) if sr.index.nlevels > 2: if index_levels is None: raise ValueError("index_levels must be specified") if column_levels is None: raise ValueError("column_levels must be specified") else: if index_levels is None: index_levels = 0 if column_levels is None: column_levels = 1 unstacked, (new_index, new_columns) = unstack_to_array( sr, levels=(index_levels, column_levels), sort=sort, return_indexes=True, ) df = pd.DataFrame(unstacked, index=new_index, columns=new_columns) if symmetric: return make_symmetric(df, sort=sort) return df # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# =================================================================================== """Classes for wrapping NumPy arrays into Series/DataFrames.""" import numpy as np import pandas as pd from pandas.core.groupby import GroupBy as PandasGroupBy from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base import indexes, reshaping from vectorbtpro.base.grouping.base import Grouper from vectorbtpro.base.indexes import stack_indexes, concat_indexes, IndexApplier from vectorbtpro.base.indexing import IndexingError, ExtPandasIndexer, index_dict, IdxSetter, IdxSetterFactory, IdxDict from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.array_ import is_range, cast_to_min_precision, cast_to_max_precision from vectorbtpro.utils.attr_ import AttrResolverMixin, AttrResolverMixinT from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer from vectorbtpro.utils.config import Configured, merge_dicts, resolve_dict from vectorbtpro.utils.decorators import hybrid_method, cached_method, cached_property from vectorbtpro.utils.execution import Task, execute from vectorbtpro.utils.params import ItemParamable from vectorbtpro.utils.parsing import get_func_arg_names from vectorbtpro.utils.warnings_ import warn if tp.TYPE_CHECKING: from vectorbtpro.base.accessors import BaseIDXAccessor as BaseIDXAccessorT from vectorbtpro.generic.splitting.base import Splitter as SplitterT else: BaseIDXAccessorT = "BaseIDXAccessor" SplitterT = "Splitter" __all__ = [ "ArrayWrapper", "Wrapping", ] HasWrapperT = tp.TypeVar("HasWrapperT", bound="HasWrapper") class HasWrapper(ExtPandasIndexer, ItemParamable): """Abstract class that manages a wrapper.""" @property def unwrapped(self) -> tp.Any: """Unwrapped object.""" raise NotImplementedError @hybrid_method def should_wrap(cls_or_self) -> bool: """Whether to wrap where applicable.""" return 
def select_col(
    self: HasWrapperT,
    column: tp.Any = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> HasWrapperT:
    """Select one column/group.

    The instance is first regrouped using `group_by` and `**kwargs`.

    `column` is looked up as a label first; if the label is not found and `column`
    is an integer, it is used as an integer position. If `column` is None and
    the regrouped instance is two-dimensional with a single column/group, that one is selected.

    Raises:
        TypeError: If the instance already holds one column/group while `column` is given,
            if it holds multiple columns/groups while `column` is None,
            or if the selection still returns multiple columns/groups.
        KeyError: If `column` is neither a known label nor an integer.
    """
    regrouped = self.regroup(group_by, **kwargs)
    wrapper = regrouped.wrapper

    def _squeeze(out: HasWrapperT) -> HasWrapperT:
        # Reduce a two-dimensional selection that holds a single column/group
        if out.wrapper.get_ndim() != 2:
            return out
        if out.wrapper.get_shape_2d()[1] == 1:
            return out.iloc[0] if out.column_only_select else out.iloc[:, 0]
        if wrapper.grouper.is_grouped():
            raise TypeError("Could not select one group: multiple groups returned")
        raise TypeError("Could not select one column: multiple columns returned")

    def _select_by_position(position: int) -> HasWrapperT:
        # Fallback used when the label lookup fails for an integer
        if regrouped.column_only_select:
            return _squeeze(regrouped.iloc[position])
        return _squeeze(regrouped.iloc[:, position])

    if column is None and wrapper.get_ndim() == 2 and wrapper.get_shape_2d()[1] == 1:
        column = 0

    if column is None:
        # Nothing to select: the instance must already be one-dimensional
        if wrapper.grouper.is_grouped():
            if wrapper.grouped_ndim == 1:
                return regrouped
            raise TypeError("Only one group is allowed. Use indexing or column argument.")
        if wrapper.ndim == 1:
            return regrouped
        raise TypeError("Only one column is allowed. Use indexing or column argument.")

    if wrapper.grouper.is_grouped():
        if wrapper.grouped_ndim == 1:
            raise TypeError("This instance already contains one group of data")
        if column not in wrapper.get_columns():
            if isinstance(column, int):
                return _select_by_position(column)
            raise KeyError(f"Group '{column}' not found")
    else:
        if wrapper.ndim == 1:
            raise TypeError("This instance already contains one column of data")
        if column not in wrapper.columns:
            if isinstance(column, int):
                return _select_by_position(column)
            raise KeyError(f"Column '{column}' not found")
    return _squeeze(regrouped[column])
one group of data") if obj_ungrouped: mask = _wrapper.grouper.group_by == column if not mask.any(): raise KeyError(f"Group '{column}' not found") if isinstance(obj, pd.DataFrame): return obj.loc[:, mask] return obj.loc[mask] else: if column not in _wrapper.get_columns(): if isinstance(column, int): if isinstance(obj, pd.DataFrame): return _check_out_dim(obj.iloc[:, column], True) return _check_out_dim(obj.iloc[column], False) raise KeyError(f"Group '{column}' not found") else: if _wrapper.ndim == 1: raise TypeError("This instance already contains one column of data") if column not in _wrapper.columns: if isinstance(column, int): if isinstance(obj, pd.DataFrame): return _check_out_dim(obj.iloc[:, column], True) return _check_out_dim(obj.iloc[column], False) raise KeyError(f"Column '{column}' not found") if isinstance(obj, pd.DataFrame): return _check_out_dim(obj[column], True) return _check_out_dim(obj[column], False) if not _wrapper.grouper.is_grouped(): if _wrapper.ndim == 1: return obj raise TypeError("Only one column is allowed. Use indexing or column argument.") if _wrapper.grouped_ndim == 1: return obj raise TypeError("Only one group is allowed. Use indexing or column argument.") # ############# Splitting ############# # def split( self, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, wrap: tp.Optional[bool] = None, **kwargs, ) -> tp.Any: """Split this instance. Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_take`.""" from vectorbtpro.generic.splitting.base import Splitter if splitter_cls is None: splitter_cls = Splitter if wrap is None: wrap = self.should_wrap() wrapped_self = self if wrap else self.unwrapped return splitter_cls.split_and_take(self.wrapper.index, wrapped_self, *args, **kwargs) def split_apply( self, apply_func: tp.Union[str, tp.Callable], *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, wrap: tp.Optional[bool] = None, **kwargs, ) -> tp.Any: """Split this instance and apply a function to each split. 
def chunk(
    self: HasWrapperT,
    axis: tp.Optional[int] = None,
    min_size: tp.Optional[int] = None,
    n_chunks: tp.Union[None, int, str] = None,
    chunk_len: tp.Union[None, int, str] = None,
    chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
    select: bool = False,
    wrap: tp.Optional[bool] = None,
    return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[HasWrapperT, tp.Tuple[ChunkMeta, HasWrapperT]]]:
    """Yield this instance chunk by chunk along `axis`.

    A missing `axis` resolves to 0 for a one-dimensional instance and to 1 otherwise.
    Axis 1 cannot be used with a one-dimensional instance.

    If `chunk_meta` is not provided, it is generated with
    `vectorbtpro.utils.chunking.iter_chunk_meta` from the size of the axis together with
    `min_size`, `n_chunks`, and `chunk_len`.

    Each chunk is taken with `ArraySelector` if `select` is True and with `ArraySlicer`
    otherwise. If `wrap` is None, it is resolved using `should_wrap`; if it ends up False,
    chunks are taken from the unwrapped object.

    If `return_chunk_meta` is True, yields tuples of the chunk meta and the chunk."""
    is_1d = self.wrapper.ndim == 1
    if axis is None:
        axis = 0 if is_1d else 1
    if is_1d and axis == 1:
        raise TypeError("Axis 1 is not supported for one dimension")
    checks.assert_in(axis, (0, 1))
    size = self.wrapper.shape_2d[axis]

    if wrap is None:
        wrap = self.should_wrap()
    source = self if wrap else self.unwrapped

    if chunk_meta is None:
        chunk_meta = iter_chunk_meta(
            size=size,
            min_size=min_size,
            n_chunks=n_chunks,
            chunk_len=chunk_len,
        )

    # The taker class does not depend on the chunk; a new taker is still created per chunk
    taker_cls = ArraySelector if select else ArraySlicer
    for meta in chunk_meta:
        taken = taker_cls(axis=axis).take(source, meta)
        if return_chunk_meta:
            yield meta, taken
        else:
            yield taken
def items(
    self,
    group_by: tp.GroupByLike = None,
    apply_group_by: bool = False,
    keep_2d: bool = False,
    key_as_index: bool = False,
    wrap: tp.Optional[bool] = None,
) -> tp.Items:
    """Yield `(key, value)` pairs, one per column or per group.

    Groups are iterated if the instance is grouped and `Wrapping.group_select` is True,
    columns otherwise.

    With `apply_group_by` set to True (or `group_by` left as None), the instance is first
    regrouped using `group_by` and then iterated. With `apply_group_by` set to False,
    `group_by` only instructs how to iterate: the instance itself keeps its grouping, and
    an error is raised if it is already grouped with group selection enabled.

    Set `keep_2d` to True to select each item with a range such that it stays two-dimensional.
    Set `key_as_index` to True to yield each key as an index instead of a single label.
    If `wrap` is None, it is resolved using `should_wrap`; if it ends up False, each value
    is yielded unwrapped."""
    if wrap is None:
        wrap = self.should_wrap()

    def _finalize(inst):
        return inst if wrap else inst.unwrapped

    if group_by is not None and not apply_group_by:
        # `group_by` drives the iteration only, the instance keeps its own grouping
        if self.group_select and self.wrapper.grouper.is_grouped():
            raise ValueError("Cannot change grouping")
        iter_wrapper = self.wrapper.regroup(group_by=group_by)
        if iter_wrapper.get_ndim() == 1:
            group_keys = iter_wrapper.get_columns()
            yield (group_keys if key_as_index else group_keys[0]), _finalize(self)
            return
        for group, group_idxs in iter_wrapper.grouper.iter_groups(key_as_index=key_as_index):
            # A group with a single column is reduced unless two dimensions are requested
            positions = group_idxs if keep_2d or len(group_idxs) > 1 else group_idxs[0]
            if self.column_only_select:
                yield group, _finalize(self.iloc[positions])
            else:
                yield group, _finalize(self.iloc[:, positions])
        return

    _self = self.regroup(group_by=group_by)
    if _self.group_select and _self.wrapper.grouper.is_grouped():
        columns = _self.wrapper.get_columns()
        ndim = _self.wrapper.get_ndim()
    else:
        columns = _self.wrapper.columns
        ndim = _self.wrapper.ndim

    if ndim == 1:
        # Only one column/group: the instance is the only item
        yield (columns if key_as_index else columns[0]), _finalize(_self)
        return

    for i in range(len(columns)):
        key = columns[[i]] if key_as_index else columns[i]
        position = slice(i, i + 1) if keep_2d else i
        if _self.column_only_select:
            yield key, _finalize(_self.iloc[position])
        else:
            yield key, _finalize(_self.iloc[:, position])
@staticmethod
def extract_init_kwargs(**kwargs) -> tp.Tuple[tp.Kwargs, tp.Kwargs]:
    """Separate keyword arguments meant for `ArrayWrapper` or `Grouper` from the rest.

    An argument is considered an initialization argument if its name is among the names
    that `get_func_arg_names` reports for `ArrayWrapper.__init__` or `Grouper.__init__`.

    Returns a tuple of two dicts: the initialization arguments and the remaining arguments.
    Both keep the order in which the arguments were passed."""
    wrapper_arg_names = get_func_arg_names(ArrayWrapper.__init__)
    grouper_arg_names = get_func_arg_names(Grouper.__init__)

    def _is_init_arg(name: str) -> bool:
        return name in wrapper_arg_names or name in grouper_arg_names

    init_kwargs = {k: v for k, v in kwargs.items() if _is_init_arg(k)}
    other_kwargs = {k: v for k, v in kwargs.items() if not _is_init_arg(k)}
    return init_kwargs, other_kwargs
@hybrid_method
def row_stack(
    cls_or_self: tp.MaybeType[ArrayWrapperT],
    *wrappers: tp.MaybeTuple[ArrayWrapperT],
    index: tp.Optional[tp.IndexLike] = None,
    columns: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
    group_by: tp.GroupByLike = None,
    stack_columns: bool = True,
    index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
    keys: tp.Optional[tp.IndexLike] = None,
    clean_index_kwargs: tp.KwargsLike = None,
    verify_integrity: bool = True,
    **kwargs,
) -> ArrayWrapperT:
    """Stack `ArrayWrapper` instances on top of each other (along rows).

    When called on an instance, that instance becomes the first wrapper to be stacked.

    The index is built with `vectorbtpro.base.indexes.concat_indexes` unless provided
    via `index`.

    All wrappers with a defined frequency must share it; the first defined frequency
    is taken unless provided via `freq`.

    Columns that differ between wrappers are stacked upon each other using `stack_indexes`
    if `stack_columns` is True, otherwise an error is raised. Custom columns can be
    provided via `columns`.

    If neither `group_by` nor `grouper` is provided, the wrappers must be either all grouped
    or all ungrouped, and grouped wrappers must have the same groups.

    The number of dimensions is taken from the first wrapper and overridden by any
    wrapper with more than one dimension, unless `ndim` is provided.

    The remaining arguments are resolved with `ArrayWrapper.resolve_stack_kwargs`, which
    requires the configs of all wrappers and their groupers to be compatible apart from
    the arguments provided explicitly via `kwargs`."""
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        cls = type(cls_or_self)
        wrappers = (cls_or_self, *wrappers)
    if len(wrappers) == 1:
        wrappers = wrappers[0]
    wrappers = list(wrappers)
    for candidate in wrappers:
        if not checks.is_instance_of(candidate, ArrayWrapper):
            raise TypeError("Each object to be merged must be an instance of ArrayWrapper")

    # Index
    if index is None:
        index = concat_indexes(
            [w.index for w in wrappers],
            index_concat_method=index_concat_method,
            keys=keys,
            clean_index_kwargs=clean_index_kwargs,
            verify_integrity=verify_integrity,
            axis=0,
        )
    elif not isinstance(index, pd.Index):
        index = pd.Index(index)
    kwargs["index"] = index

    # Frequency: take the first defined one and compare the others against it
    if freq is None:
        for w in wrappers:
            w_freq = w.freq
            if freq is None:
                freq = w_freq
            elif w_freq is not None and freq != w_freq:
                raise ValueError("Objects to be merged must have the same frequency")
    kwargs["freq"] = freq

    # Columns: start with the first wrapper and stack whatever differs
    if columns is None:
        for w in wrappers:
            if columns is None:
                columns = w.columns
            elif not checks.is_index_equal(columns, w.columns):
                if not stack_columns:
                    raise ValueError("Objects to be merged must have the same columns")
                columns = stack_indexes(
                    (columns, w.columns),
                    **resolve_dict(clean_index_kwargs),
                )
    elif not isinstance(columns, pd.Index):
        columns = pd.Index(columns)
    kwargs["columns"] = columns

    # Grouping
    if "grouper" in kwargs:
        if not checks.is_index_equal(columns, kwargs["grouper"].index):
            raise ValueError("Columns and grouper index must match")
        if group_by is not None:
            kwargs["group_by"] = group_by
    else:
        if group_by is None:
            all_grouped = None
            for w in wrappers:
                w_grouped = w.grouper.is_grouped()
                if all_grouped is None:
                    all_grouped = w_grouped
                elif all_grouped is not w_grouped:
                    raise ValueError("Objects to be merged must be either grouped or not")
            if all_grouped:
                for w in wrappers:
                    w_groups, w_grouped_index = w.grouper.get_groups_and_index()
                    # Group label of each column
                    w_group_by = w_grouped_index[w_groups]
                    if group_by is None:
                        group_by = w_group_by
                    elif not checks.is_index_equal(group_by, w_group_by):
                        raise ValueError("Objects to be merged must have the same groups")
            else:
                group_by = False
        kwargs["group_by"] = group_by

    # Dimensions
    if "ndim" not in kwargs:
        new_ndim = None
        for w in wrappers:
            if new_ndim is None or w.ndim > 1:
                new_ndim = w.ndim
        kwargs["ndim"] = new_ndim

    return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs))
A custom index can be provided via `index`. Frequency must be the same across all indexes. A custom frequency can be provided via `freq`. Concatenates columns and groups using `vectorbtpro.base.indexes.concat_indexes`. If any of the instances has `column_only_select` being enabled, the final wrapper will also enable it. If any of the instances has `group_select` or other grouping-related flags being disabled, the final wrapper will also disable them. All instances must contain the same keys and values in their configs and configs of their grouper instances, apart from those arguments provided explicitly via `kwargs`.""" if not isinstance(cls_or_self, type): wrappers = (cls_or_self, *wrappers) cls = type(cls_or_self) else: cls = cls_or_self if len(wrappers) == 1: wrappers = wrappers[0] wrappers = list(wrappers) for wrapper in wrappers: if not checks.is_instance_of(wrapper, ArrayWrapper): raise TypeError("Each object to be merged must be an instance of ArrayWrapper") for wrapper in wrappers: if wrapper.index.has_duplicates: raise ValueError("Index of some objects to be merged contains duplicates") if index is None: new_index = None for wrapper in wrappers: if new_index is None: new_index = wrapper.index else: if not checks.is_index_equal(new_index, wrapper.index): if not union_index: raise ValueError( "Objects to be merged must have the same index. " "Use union_index=True to merge index as well." 
) else: if new_index.dtype != wrapper.index.dtype: raise ValueError("Indexes to be merged must have the same data type") new_index = new_index.union(wrapper.index) if not new_index.is_monotonic_increasing: raise ValueError("Merged index must be monotonically increasing") index = new_index elif not isinstance(index, pd.Index): index = pd.Index(index) kwargs["index"] = index if freq is None: new_freq = None for wrapper in wrappers: if new_freq is None: new_freq = wrapper.freq else: if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq: raise ValueError("Objects to be merged must have the same frequency") freq = new_freq kwargs["freq"] = freq if columns is None: columns = concat_indexes( [wrapper.columns for wrapper in wrappers], index_concat_method=col_concat_method, keys=keys, clean_index_kwargs=clean_index_kwargs, verify_integrity=verify_integrity, axis=1, ) elif not isinstance(columns, pd.Index): columns = pd.Index(columns) kwargs["columns"] = columns if "grouper" in kwargs: if not checks.is_index_equal(columns, kwargs["grouper"].index): raise ValueError("Columns and grouper index must match") if group_by is not None: kwargs["group_by"] = group_by else: if group_by is None: any_grouped = False for wrapper in wrappers: if wrapper.grouper.is_grouped(): any_grouped = True break if any_grouped: group_by = concat_indexes( [wrapper.grouper.get_stretched_index() for wrapper in wrappers], index_concat_method=group_concat_method, keys=keys, clean_index_kwargs=clean_index_kwargs, verify_integrity=verify_integrity, axis=2, ) else: group_by = False kwargs["group_by"] = group_by if "ndim" not in kwargs: kwargs["ndim"] = 2 if "grouped_ndim" not in kwargs: kwargs["grouped_ndim"] = None if "column_only_select" not in kwargs: column_only_select = None for wrapper in wrappers: if column_only_select is None or wrapper.column_only_select: column_only_select = wrapper.column_only_select kwargs["column_only_select"] = column_only_select if 
"range_only_select" not in kwargs: range_only_select = None for wrapper in wrappers: if range_only_select is None or wrapper.range_only_select: range_only_select = wrapper.range_only_select kwargs["range_only_select"] = range_only_select if "group_select" not in kwargs: group_select = None for wrapper in wrappers: if group_select is None or not wrapper.group_select: group_select = wrapper.group_select kwargs["group_select"] = group_select if "grouper" not in kwargs: if "allow_enable" not in kwargs: allow_enable = None for wrapper in wrappers: if allow_enable is None or not wrapper.grouper.allow_enable: allow_enable = wrapper.grouper.allow_enable kwargs["allow_enable"] = allow_enable if "allow_disable" not in kwargs: allow_disable = None for wrapper in wrappers: if allow_disable is None or not wrapper.grouper.allow_disable: allow_disable = wrapper.grouper.allow_disable kwargs["allow_disable"] = allow_disable if "allow_modify" not in kwargs: allow_modify = None for wrapper in wrappers: if allow_modify is None or not wrapper.grouper.allow_modify: allow_modify = wrapper.grouper.allow_modify kwargs["allow_modify"] = allow_modify return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs)) def __init__( self, index: tp.IndexLike, columns: tp.Optional[tp.IndexLike] = None, ndim: tp.Optional[int] = None, freq: tp.Optional[tp.FrequencyLike] = None, parse_index: tp.Optional[bool] = None, column_only_select: tp.Optional[bool] = None, range_only_select: tp.Optional[bool] = None, group_select: tp.Optional[bool] = None, grouped_ndim: tp.Optional[int] = None, grouper: tp.Optional[Grouper] = None, **kwargs, ) -> None: checks.assert_not_none(index, arg_name="index") index = dt.prepare_dt_index(index, parse_index=parse_index) if columns is None: columns = [None] if not isinstance(columns, pd.Index): columns = pd.Index(columns) if ndim is None: if len(columns) == 1 and not isinstance(columns, pd.MultiIndex): ndim = 1 else: ndim = 2 else: if len(columns) > 1: ndim = 2 
grouper_arg_names = get_func_arg_names(Grouper.__init__) grouper_kwargs = dict() for k in list(kwargs.keys()): if k in grouper_arg_names: grouper_kwargs[k] = kwargs.pop(k) if grouper is None: grouper = Grouper(columns, **grouper_kwargs) elif not checks.is_index_equal(columns, grouper.index) or len(grouper_kwargs) > 0: grouper = grouper.replace(index=columns, **grouper_kwargs) HasWrapper.__init__(self) Configured.__init__( self, index=index, columns=columns, ndim=ndim, freq=freq, parse_index=parse_index, column_only_select=column_only_select, range_only_select=range_only_select, group_select=group_select, grouped_ndim=grouped_ndim, grouper=grouper, **kwargs, ) self._index = index self._columns = columns self._ndim = ndim self._freq = freq self._parse_index = parse_index self._column_only_select = column_only_select self._range_only_select = range_only_select self._group_select = group_select self._grouper = grouper self._grouped_ndim = grouped_ndim def indexing_func_meta( self: ArrayWrapperT, pd_indexing_func: tp.PandasIndexingFunc, index: tp.Optional[tp.IndexLike] = None, columns: tp.Optional[tp.IndexLike] = None, column_only_select: tp.Optional[bool] = None, range_only_select: tp.Optional[bool] = None, group_select: tp.Optional[bool] = None, return_slices: bool = True, return_none_slices: bool = True, return_scalars: bool = True, group_by: tp.GroupByLike = None, wrapper_kwargs: tp.KwargsLike = None, ) -> dict: """Perform indexing on `ArrayWrapper` and also return metadata. Takes into account column grouping. Flipping rows and columns is not allowed. If one row is selected, the result will still be a Series when indexing a Series and a DataFrame when indexing a DataFrame. Set `column_only_select` to True to index the array wrapper as a Series of columns/groups. This way, selection of index (axis 0) can be avoided. Set `range_only_select` to True to allow selection of rows only using slices. Set `group_select` to True to allow selection of groups. 
Otherwise, indexing is performed on columns, even if grouping is enabled. Takes effect only if grouping is enabled. Returns the new array wrapper, row indices, column indices, and group indices. If `return_slices` is True (default), indices will be returned as a slice if they were identified as a range. If `return_none_slices` is True (default), indices will be returned as a slice `(None, None, None)` if the axis hasn't been changed. !!! note If `column_only_select` is True, make sure to index the array wrapper as a Series of columns rather than a DataFrame. For example, the operation `.iloc[:, :2]` should become `.iloc[:2]`. Operations are not allowed if the object is already a Series and thus has only one column/group.""" if column_only_select is None: column_only_select = self.column_only_select if range_only_select is None: range_only_select = self.range_only_select if group_select is None: group_select = self.group_select if wrapper_kwargs is None: wrapper_kwargs = {} _self = self.regroup(group_by) group_select = group_select and _self.grouper.is_grouped() if index is None: index = _self.index if not isinstance(index, pd.Index): index = pd.Index(index) if columns is None: if group_select: columns = _self.get_columns() else: columns = _self.columns if not isinstance(columns, pd.Index): columns = pd.Index(columns) if group_select: # Groups as columns i_wrapper = ArrayWrapper(index, columns, _self.get_ndim()) else: # Columns as columns i_wrapper = ArrayWrapper(index, columns, _self.ndim) n_rows = len(index) n_cols = len(columns) def _resolve_arr(arr, n): if checks.is_np_array(arr) and is_range(arr): if arr[0] == 0 and arr[-1] == n - 1: if return_none_slices: return slice(None, None, None), False return arr, False if return_slices: return slice(arr[0], arr[-1] + 1, None), True return arr, True if isinstance(arr, np.integer): arr = arr.item() columns_changed = True if isinstance(arr, int): if arr == 0 and n == 1: columns_changed = False if not return_scalars: arr = 
np.array([arr]) return arr, columns_changed if column_only_select: if i_wrapper.ndim == 1: raise IndexingError("Columns only: This instance already contains one column of data") try: col_mapper = pd_indexing_func(i_wrapper.wrap_reduced(np.arange(n_cols), columns=columns)) except pd.core.indexing.IndexingError as e: warn("Columns only: Make sure to treat this instance as a Series of columns rather than a DataFrame") raise e if checks.is_series(col_mapper): new_columns = col_mapper.index col_idxs = col_mapper.values new_ndim = 2 else: new_columns = columns[[col_mapper]] col_idxs = col_mapper new_ndim = 1 new_index = index row_idxs = np.arange(len(index)) else: init_row_mapper_values = reshaping.broadcast_array_to(np.arange(n_rows)[:, None], (n_rows, n_cols)) init_row_mapper = i_wrapper.wrap(init_row_mapper_values, index=index, columns=columns) row_mapper = pd_indexing_func(init_row_mapper) if i_wrapper.ndim == 1: if not checks.is_series(row_mapper): row_idxs = np.array([row_mapper]) new_index = index[row_idxs] else: row_idxs = row_mapper.values new_index = indexes.get_index(row_mapper, 0) col_idxs = 0 new_columns = columns new_ndim = 1 else: init_col_mapper_values = reshaping.broadcast_array_to(np.arange(n_cols)[None], (n_rows, n_cols)) init_col_mapper = i_wrapper.wrap(init_col_mapper_values, index=index, columns=columns) col_mapper = pd_indexing_func(init_col_mapper) if checks.is_frame(col_mapper): # Multiple rows and columns selected row_idxs = row_mapper.values[:, 0] col_idxs = col_mapper.values[0] new_index = indexes.get_index(row_mapper, 0) new_columns = indexes.get_index(col_mapper, 1) new_ndim = 2 elif checks.is_series(col_mapper): multi_index = isinstance(index, pd.MultiIndex) multi_columns = isinstance(columns, pd.MultiIndex) multi_name = isinstance(col_mapper.name, tuple) if multi_index and multi_name and col_mapper.name in index: one_row = True elif not multi_index and not multi_name and col_mapper.name in index: one_row = True else: one_row = False if 
multi_columns and multi_name and col_mapper.name in columns: one_col = True elif not multi_columns and not multi_name and col_mapper.name in columns: one_col = True else: one_col = False if (one_row and one_col) or (not one_row and not one_col): one_row = np.all(row_mapper.values == row_mapper.values.item(0)) one_col = np.all(col_mapper.values == col_mapper.values.item(0)) if (one_row and one_col) or (not one_row and not one_col): raise IndexingError("Could not parse indexing operation") if one_row: # One row selected row_idxs = row_mapper.values[[0]] col_idxs = col_mapper.values new_index = index[row_idxs] new_columns = indexes.get_index(col_mapper, 0) new_ndim = 2 else: # One column selected row_idxs = row_mapper.values col_idxs = col_mapper.values[0] new_index = indexes.get_index(row_mapper, 0) new_columns = columns[[col_idxs]] new_ndim = 1 else: # One row and column selected row_idxs = np.array([row_mapper]) col_idxs = col_mapper new_index = index[row_idxs] new_columns = columns[[col_idxs]] new_ndim = 1 if _self.grouper.is_grouped(): # Grouping enabled if np.asarray(row_idxs).ndim == 0: raise IndexingError("Flipping index and columns is not allowed") if group_select: # Selection based on groups # Get indices of columns corresponding to selected groups group_idxs = col_idxs col_idxs, new_groups = _self.grouper.select_groups(group_idxs) ungrouped_columns = _self.columns[col_idxs] if new_ndim == 1 and len(ungrouped_columns) == 1: ungrouped_ndim = 1 col_idxs = col_idxs[0] else: ungrouped_ndim = 2 row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) if range_only_select and rows_changed: if not isinstance(row_idxs, slice): raise ValueError("Rows can be selected only by slicing") if row_idxs.step not in (1, None): raise ValueError("Slice for selecting rows must have a step of 1 or None") col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1]) return dict( 
new_wrapper=_self.replace( **merge_dicts( dict( index=new_index, columns=ungrouped_columns, ndim=ungrouped_ndim, grouped_ndim=new_ndim, group_by=new_columns[new_groups], ), wrapper_kwargs, ) ), row_idxs=row_idxs, rows_changed=rows_changed, col_idxs=col_idxs, columns_changed=columns_changed, group_idxs=group_idxs, groups_changed=groups_changed, ) # Selection based on columns group_idxs = _self.grouper.get_groups()[col_idxs] new_group_by = _self.grouper.group_by[reshaping.to_1d_array(col_idxs)] row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) if range_only_select and rows_changed: if not isinstance(row_idxs, slice): raise ValueError("Rows can be selected only by slicing") if row_idxs.step not in (1, None): raise ValueError("Slice for selecting rows must have a step of 1 or None") col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1]) return dict( new_wrapper=_self.replace( **merge_dicts( dict( index=new_index, columns=new_columns, ndim=new_ndim, grouped_ndim=None, group_by=new_group_by, ), wrapper_kwargs, ) ), row_idxs=row_idxs, rows_changed=rows_changed, col_idxs=col_idxs, columns_changed=columns_changed, group_idxs=group_idxs, groups_changed=groups_changed, ) # Grouping disabled row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0]) if range_only_select and rows_changed: if not isinstance(row_idxs, slice): raise ValueError("Rows can be selected only by slicing") if row_idxs.step not in (1, None): raise ValueError("Slice for selecting rows must have a step of 1 or None") col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1]) return dict( new_wrapper=_self.replace( **merge_dicts( dict( index=new_index, columns=new_columns, ndim=new_ndim, grouped_ndim=None, group_by=None, ), wrapper_kwargs, ) ), row_idxs=row_idxs, rows_changed=rows_changed, col_idxs=col_idxs, columns_changed=columns_changed, group_idxs=col_idxs, 
@staticmethod
def select_from_flex_array(
    arr: tp.ArrayLike,
    row_idxs: tp.Optional[tp.Union[int, tp.Array1d, slice]] = None,
    col_idxs: tp.Optional[tp.Union[int, tp.Array1d, slice]] = None,
    rows_changed: bool = True,
    columns_changed: bool = True,
    rotate_rows: bool = False,
    rotate_cols: bool = True,
) -> tp.Array2d:
    """Select rows and columns from a flexible array.

    Always returns a 2-dim NumPy array.

    Args:
        arr: Array-like object; converted with `reshaping.to_2d_array`.
        row_idxs: Integer, array of integers, or slice selecting rows.
            None leaves the rows untouched.
        col_idxs: Integer, array of integers, or slice selecting columns.
            None leaves the columns untouched.
        rows_changed: If False, the rows are left untouched regardless of `row_idxs`.
        columns_changed: If False, the columns are left untouched regardless of `col_idxs`.
        rotate_rows: If True, integer row indices that exceed the number of rows
            wrap around (modulo the number of rows).
        rotate_cols: If True, integer column indices that exceed the number of columns
            wrap around (modulo the number of columns).

    An axis of length one is never indexed, so its single element is kept as-is.

    Slices are applied as-is and never rotated. Open-ended slices such as `slice(2, None)`
    are supported, and an empty array of indices selects nothing along that axis.

    Raises:
        IndexError: If integer indices are out of bounds and rotation is disabled
            for that axis."""
    new_arr = arr_2d = reshaping.to_2d_array(arr)
    n_rows, n_cols = arr_2d.shape[0], arr_2d.shape[1]

    def _resolve_idxs(idxs, axis_len, rotate):
        # Slices are clipped by NumPy itself, so their bounds need no inspection;
        # reading `stop` would also fail for open-ended slices
        if isinstance(idxs, slice):
            return idxs
        idxs = reshaping.to_1d_array(idxs)
        # Guard against empty arrays since np.max has no identity for them
        if rotate and len(idxs) > 0 and np.max(idxs) >= axis_len:
            return idxs % axis_len
        return idxs

    if row_idxs is not None and rows_changed and n_rows > 1:
        new_arr = new_arr[_resolve_idxs(row_idxs, n_rows, rotate_rows), :]
    if col_idxs is not None and columns_changed and n_cols > 1:
        new_arr = new_arr[:, _resolve_idxs(col_idxs, n_cols, rotate_cols)]
    return new_arr
@property
def name(self) -> tp.Any:
    """Name.

    Only a one-dimensional wrapper has a name: the first column label, where a label
    equal to 0 is reported as None. Any other wrapper returns None."""
    if self.ndim != 1:
        return None
    first_label = self.columns[0]
    return None if first_label == 0 else first_label
@property
def shape(self) -> tp.Shape:
    """Shape.

    A one-element tuple with the index length if the wrapper is one-dimensional,
    otherwise the index length followed by the number of columns."""
    n_rows = len(self.index)
    if self.ndim == 1:
        return (n_rows,)
    return n_rows, len(self.columns)
@property
def grouped_ndim(self) -> int:
    """Number of dimensions under column grouping.

    An explicitly stored value takes precedence. Otherwise, a grouped wrapper has two
    dimensions if there is more than one group and one dimension if not, while
    a wrapper without grouping falls back to `ArrayWrapper.ndim`."""
    stored = self._grouped_ndim
    if stored is not None:
        return stored
    if not self.grouper.is_grouped():
        return self.ndim
    if self.grouper.get_group_count() > 1:
        return 2
    return 1
@cached_method(whitelist=True)
def resolve(self: ArrayWrapperT, group_by: tp.GroupByLike = None, **kwargs) -> ArrayWrapperT:
    """Resolve this instance.

    Regroups with `ArrayWrapper.regroup` first (`**kwargs` are forwarded to it). If the
    result is grouped, replaces columns and other metadata with groups and disables
    the grouping; otherwise returns the regrouped instance as-is."""
    regrouped = self.regroup(group_by=group_by, **kwargs)
    if not regrouped.grouper.is_grouped():
        return regrouped  # important for keeping cache

    group_index = regrouped.grouper.get_index()
    group_ndim = regrouped.grouped_ndim
    return regrouped.replace(
        columns=group_index,
        ndim=group_ndim,
        grouped_ndim=None,
        group_by=None,
    )
NumPy array using the stored metadata. Runs the following pipeline: 1) Converts to NumPy array 2) Fills NaN (optional) 3) Wraps using index, columns, and dtype (optional) 4) Converts to index (optional) 5) Converts to timedelta using `ArrayWrapper.arr_to_timedelta` (optional)""" from vectorbtpro._settings import settings wrapping_cfg = settings["wrapping"] if zero_to_none is None: zero_to_none = wrapping_cfg["zero_to_none"] if min_precision is None: min_precision = wrapping_cfg["min_precision"] if max_precision is None: max_precision = wrapping_cfg["max_precision"] if prec_float_only is None: prec_float_only = wrapping_cfg["prec_float_only"] if prec_check_bounds is None: prec_check_bounds = wrapping_cfg["prec_check_bounds"] if prec_strict is None: prec_strict = wrapping_cfg["prec_strict"] if silence_warnings is None: silence_warnings = wrapping_cfg["silence_warnings"] _self = self.resolve(group_by=group_by) if index is None: index = _self.index if not isinstance(index, pd.Index): index = pd.Index(index) if columns is None: columns = _self.columns if not isinstance(columns, pd.Index): columns = pd.Index(columns) if len(columns) == 1: name = columns[0] if zero_to_none and name == 0: # was a Series before name = None else: name = None def _apply_dtype(obj): if dtype is None: return obj return obj.astype(dtype, errors="ignore") def _wrap(arr): orig_arr = arr arr = np.asarray(arr) if fillna is not None: arr[pd.isnull(arr)] = fillna shape_2d = (arr.shape[0] if arr.ndim > 0 else 1, arr.shape[1] if arr.ndim > 1 else 1) target_shape_2d = (len(index), len(columns)) if shape_2d != target_shape_2d: if isinstance(orig_arr, (pd.Series, pd.DataFrame)): arr = reshaping.align_pd_arrays(orig_arr, to_index=index, to_columns=columns).values arr = reshaping.broadcast_array_to(arr, target_shape_2d) arr = reshaping.soft_to_ndim(arr, self.ndim) if min_precision is not None: arr = cast_to_min_precision( arr, min_precision, float_only=prec_float_only, ) if max_precision is not None: arr = 
def wrap_reduced(
    self,
    arr: tp.ArrayLike,
    group_by: tp.GroupByLike = None,
    name_or_index: tp.NameIndex = None,
    columns: tp.Optional[tp.IndexLike] = None,
    force_1d: bool = False,
    fillna: tp.Optional[tp.Scalar] = None,
    dtype: tp.Optional[tp.PandasDTypeLike] = None,
    to_timedelta: bool = False,
    to_index: bool = False,
    silence_warnings: tp.Optional[bool] = None,
) -> tp.MaybeSeriesFrame:
    """Wrap result of reduction.

    `name_or_index` can be the name of the resulting series if reducing to a scalar per column,
    or the index of the resulting series/dataframe if reducing to an array per column.
    A string is only ever used as a name; it is dropped wherever an index is expected.
    `columns` can be set to override object's default columns.

    If `to_index` is True, `dtype` defaults to `int_` and `fillna` to -1, and every value
    other than -1 is looked up in the index of this wrapper (-1 becomes NaN).
    If `to_timedelta` is True, the result is passed to `ArrayWrapper.arr_to_timedelta`.

    Null values are filled on a copy, such that the passed object is not modified.

    See `ArrayWrapper.wrap` for the pipeline."""
    from vectorbtpro._settings import settings

    wrapping_cfg = settings["wrapping"]

    if silence_warnings is None:
        silence_warnings = wrapping_cfg["silence_warnings"]

    _self = self.resolve(group_by=group_by)
    if columns is None:
        columns = _self.columns
    if not isinstance(columns, pd.Index):
        columns = pd.Index(columns)

    if to_index:
        if dtype is None:
            dtype = int_
        if fillna is None:
            fillna = -1

    # A string can only serve as a name, never as an index
    index_only = None if isinstance(name_or_index, str) else name_or_index

    def _apply_dtype(obj):
        if dtype is None:
            return obj
        return obj.astype(dtype, errors="ignore")

    def _wrap_series_array(values):
        # Array per Series: named after the only column unless that column is 0
        sr_name = columns[0]
        if sr_name == 0:
            sr_name = None
        return _apply_dtype(pd.Series(values, index=index_only, name=sr_name))

    def _wrap_reduced(arr):
        if isinstance(arr, dict):
            arr = reshaping.to_pd_array(arr)
        if isinstance(arr, pd.Series):
            if not checks.is_index_equal(arr.index, columns):
                arr = arr.iloc[indexes.align_indexes(arr.index, columns)[0]]
        arr = np.asarray(arr)
        if force_1d and arr.ndim == 0:
            arr = arr[None]
        if fillna is not None:
            if arr.ndim == 0:
                if pd.isnull(arr):
                    # Keep a 0-dim array: a plain Python scalar has no `ndim`
                    arr = np.asarray(fillna)
            else:
                null_mask = pd.isnull(arr)
                if null_mask.any():
                    # `np.asarray` may return a view of the input, so fill a copy
                    arr = arr.copy()
                    arr[null_mask] = fillna
        if arr.ndim == 0:
            # Scalar per Series/DataFrame
            return _apply_dtype(pd.Series(arr[None]))[0]
        if arr.ndim == 1:
            if not force_1d and _self.ndim == 1:
                if arr.shape[0] == 1:
                    # Scalar per Series/DataFrame with one column
                    return _apply_dtype(pd.Series(arr))[0]
                return _wrap_series_array(arr)
            # Scalar per column in DataFrame
            if arr.shape[0] == 1 and len(columns) > 1:
                arr = reshaping.broadcast_array_to(arr, len(columns))
            return _apply_dtype(pd.Series(arr, index=columns, name=name_or_index))
        if arr.ndim == 2:
            if arr.shape[1] == 1 and _self.ndim == 1:
                return _wrap_series_array(reshaping.soft_to_ndim(arr, 1))
            # Array per column in DataFrame
            if arr.shape[0] == 1 and len(columns) > 1:
                arr = reshaping.broadcast_array_to(arr, (arr.shape[0], len(columns)))
            return _apply_dtype(pd.DataFrame(arr, index=index_only, columns=columns))
        raise ValueError(f"{arr.ndim}-d input is not supported")

    out = _wrap_reduced(arr)
    if to_index:
        # Convert to index
        if checks.is_series(out):
            out = out.map(lambda x: self.index[x] if x != -1 else np.nan)
        elif checks.is_frame(out):
            out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan)
        else:
            out = self.index[out] if out != -1 else np.nan
    if to_timedelta:
        # Convert to timedelta
        out = self.arr_to_timedelta(out, silence_warnings=silence_warnings)
    return out
def apply_to_index(
    self: ArrayWrapperT,
    apply_func: tp.Callable,
    *args,
    axis: tp.Optional[int] = None,
    **kwargs,
) -> ArrayWrapperT:
    """Apply a function to the index or columns and return a new instance.

    `apply_func` receives the index (`axis=0`) or the columns (`axis=1`) followed by
    `*args` and `**kwargs`; its result replaces the respective labels.
    If `axis` is None, defaults to 0 for one dimension and to 1 otherwise.
    Raises `TypeError` if `axis=1` is requested for one dimension."""
    is_one_dim = self.ndim == 1
    if axis is None:
        axis = 0 if is_one_dim else 1
    if is_one_dim and axis == 1:
        raise TypeError("Axis 1 is not supported for one dimension")
    checks.assert_in(axis, (0, 1))

    if axis == 1:
        new_columns = apply_func(self.columns, *args, **kwargs)
        return self.replace(columns=new_columns)
    new_index = apply_func(self.index, *args, **kwargs)
    return self.replace(index=new_index)
})) a b c 2020-01-01 NaN NaN NaN 2020-01-02 2.0 2.0 2.0 2020-01-03 NaN NaN NaN 2020-01-04 3.0 3.0 3.0 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... ("2020-01-02", "2020-01-04"): [[1, 2, 3], [4, 5, 6]] ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 1.0 2.0 3.0 2020-01-03 NaN NaN NaN 2020-01-04 4.0 5.0 6.0 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... ("2020-01-02", "2020-01-04"): [[1, 2, 3]] ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 1.0 2.0 3.0 2020-01-03 NaN NaN NaN 2020-01-04 1.0 2.0 3.0 2020-01-05 NaN NaN NaN ``` * Set rows using slices: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.hslice(1, 3): 2 ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 2.0 2.0 2.0 2020-01-03 2.0 2.0 2.0 2020-01-04 NaN NaN NaN 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.hslice("2020-01-02", "2020-01-04"): 2 ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 2.0 2.0 2.0 2020-01-03 2.0 2.0 2.0 2020-01-04 NaN NaN NaN 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... ((0, 2), (3, 5)): [[1], [2]] ... })) a b c 2020-01-01 1.0 1.0 1.0 2020-01-02 1.0 1.0 1.0 2020-01-03 NaN NaN NaN 2020-01-04 2.0 2.0 2.0 2020-01-05 2.0 2.0 2.0 >>> wrapper.fill_and_set(vbt.index_dict({ ... ((0, 2), (3, 5)): [[1, 2, 3], [4, 5, 6]] ... })) a b c 2020-01-01 1.0 2.0 3.0 2020-01-02 1.0 2.0 3.0 2020-01-03 NaN NaN NaN 2020-01-04 4.0 5.0 6.0 2020-01-05 4.0 5.0 6.0 ``` * Set rows using index points: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.pointidx(every="2D"): 2 ... })) a b c 2020-01-01 2.0 2.0 2.0 2020-01-02 NaN NaN NaN 2020-01-03 2.0 2.0 2.0 2020-01-04 NaN NaN NaN 2020-01-05 2.0 2.0 2.0 ``` * Set rows using index ranges: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.rangeidx( ... start=("2020-01-01", "2020-01-03"), ... end=("2020-01-02", "2020-01-05") ... ): 2 ... 
})) a b c 2020-01-01 2.0 2.0 2.0 2020-01-02 NaN NaN NaN 2020-01-03 2.0 2.0 2.0 2020-01-04 2.0 2.0 2.0 2020-01-05 NaN NaN NaN ``` * Set column indices: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.colidx("a"): 2 ... })) a b c 2020-01-01 2.0 NaN NaN 2020-01-02 2.0 NaN NaN 2020-01-03 2.0 NaN NaN 2020-01-04 2.0 NaN NaN 2020-01-05 2.0 NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.colidx(("a", "b")): [1, 2] ... })) a b c 2020-01-01 1.0 2.0 NaN 2020-01-02 1.0 2.0 NaN 2020-01-03 1.0 2.0 NaN 2020-01-04 1.0 2.0 NaN 2020-01-05 1.0 2.0 NaN >>> multi_columns = pd.MultiIndex.from_arrays( ... [["a", "a", "b", "b"], [1, 2, 1, 2]], ... names=["c1", "c2"] ... ) >>> multi_wrapper = vbt.ArrayWrapper(index, multi_columns) >>> multi_wrapper.fill_and_set(vbt.index_dict({ ... vbt.colidx(("a", 2)): 2 ... })) c1 a b c2 1 2 1 2 2020-01-01 NaN 2.0 NaN NaN 2020-01-02 NaN 2.0 NaN NaN 2020-01-03 NaN 2.0 NaN NaN 2020-01-04 NaN 2.0 NaN NaN 2020-01-05 NaN 2.0 NaN NaN >>> multi_wrapper.fill_and_set(vbt.index_dict({ ... vbt.colidx("b", level="c1"): [3, 4] ... })) c1 a b c2 1 2 1 2 2020-01-01 NaN NaN 3.0 4.0 2020-01-02 NaN NaN 3.0 4.0 2020-01-03 NaN NaN 3.0 4.0 2020-01-04 NaN NaN 3.0 4.0 2020-01-05 NaN NaN 3.0 4.0 ``` * Set row and column indices: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.idx(2, 2): 2 ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 NaN NaN NaN 2020-01-03 NaN NaN 2.0 2020-01-04 NaN NaN NaN 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.idx(("2020-01-01", "2020-01-03"), 2): [1, 2] ... })) a b c 2020-01-01 NaN NaN 1.0 2020-01-02 NaN NaN NaN 2020-01-03 NaN NaN 2.0 2020-01-04 NaN NaN NaN 2020-01-05 NaN NaN NaN >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.idx(("2020-01-01", "2020-01-03"), (0, 2)): [[1, 2], [3, 4]] ... })) a b c 2020-01-01 1.0 NaN 2.0 2020-01-02 NaN NaN NaN 2020-01-03 3.0 NaN 4.0 2020-01-04 NaN NaN NaN 2020-01-05 NaN NaN NaN >>> multi_wrapper.fill_and_set(vbt.index_dict({ ... vbt.idx( ... 
vbt.pointidx(every="2d"), ... vbt.colidx(1, level="c2") ... ): [[1, 2]] ... })) c1 a b c2 1 2 1 2 2020-01-01 1.0 NaN 2.0 NaN 2020-01-02 NaN NaN NaN NaN 2020-01-03 1.0 NaN 2.0 NaN 2020-01-04 NaN NaN NaN NaN 2020-01-05 1.0 NaN 2.0 NaN >>> multi_wrapper.fill_and_set(vbt.index_dict({ ... vbt.idx( ... vbt.pointidx(every="2d"), ... vbt.colidx(1, level="c2") ... ): [[1], [2], [3]] ... })) c1 a b c2 1 2 1 2 2020-01-01 1.0 NaN 1.0 NaN 2020-01-02 NaN NaN NaN NaN 2020-01-03 2.0 NaN 2.0 NaN 2020-01-04 NaN NaN NaN NaN 2020-01-05 3.0 NaN 3.0 NaN ``` * Set rows using a template: ```pycon >>> wrapper.fill_and_set(vbt.index_dict({ ... vbt.RepEval("index.day % 2 == 0"): 2 ... })) a b c 2020-01-01 NaN NaN NaN 2020-01-02 2.0 2.0 2.0 2020-01-03 NaN NaN NaN 2020-01-04 2.0 2.0 2.0 2020-01-05 NaN NaN NaN ``` """ if isinstance(idx_setter, index_dict): idx_setter = IdxDict(idx_setter) if isinstance(idx_setter, IdxSetterFactory): idx_setter = idx_setter.get() if not isinstance(idx_setter, IdxSetter): raise ValueError("Index setter factory must return exactly one index setter") checks.assert_instance_of(idx_setter, IdxSetter) arr = idx_setter.fill_and_set( self.shape, keep_flex=keep_flex, fill_value=fill_value, index=self.index, columns=self.columns, freq=self.freq, **kwargs, ) if not keep_flex: return self.wrap(arr, group_by=False) return arr WrappingT = tp.TypeVar("WrappingT", bound="Wrapping") class Wrapping(Configured, HasWrapper, IndexApplier, AttrResolverMixin): """Class that uses `ArrayWrapper` globally.""" @classmethod def resolve_row_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs: """Resolve keyword arguments for initializing `Wrapping` after stacking along rows.""" return kwargs @classmethod def resolve_column_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs: """Resolve keyword arguments for initializing `Wrapping` after stacking along columns.""" return kwargs @classmethod def resolve_stack_kwargs(cls, *wrappings: 
def indexing_func(self: WrappingT, *args, **kwargs) -> WrappingT:
    """Perform indexing on `Wrapping`.

    Indexes the wrapper using the selection flags of this instance (`column_only_select`,
    `range_only_select`, and `group_select`) and returns a copy holding the new wrapper."""
    select_flags = dict(
        column_only_select=self.column_only_select,
        range_only_select=self.range_only_select,
        group_select=self.group_select,
    )
    indexed_wrapper = self.wrapper.indexing_func(*args, **select_flags, **kwargs)
    return self.replace(wrapper=indexed_wrapper)
def apply_to_index(
    self: WrappingT,
    apply_func: tp.Callable,
    *args,
    axis: tp.Optional[int] = None,
    **kwargs,
) -> WrappingT:
    """Apply a function to the index or columns of the wrapper and return a new instance.

    `apply_func` receives the wrapper's index (`axis=0`) or columns (`axis=1`) followed by
    `*args` and `**kwargs`; its result replaces the respective labels in a new wrapper,
    which is then set on a copy of this instance.
    If `axis` is None, defaults to 0 for a one-dimensional wrapper and to 1 otherwise.
    Raises `TypeError` if `axis=1` is requested for a one-dimensional wrapper."""
    wrapper = self.wrapper
    if axis is None:
        axis = 0 if wrapper.ndim == 1 else 1
    if wrapper.ndim == 1 and axis == 1:
        raise TypeError("Axis 1 is not supported for one dimension")
    checks.assert_in(axis, (0, 1))

    if axis == 1:
        new_wrapper = wrapper.replace(columns=apply_func(wrapper.columns, *args, **kwargs))
    else:
        new_wrapper = wrapper.replace(index=apply_func(wrapper.index, *args, **kwargs))
    return self.replace(wrapper=new_wrapper)
def resolve_self(
    self: AttrResolverMixinT,
    cond_kwargs: tp.KwargsLike = None,
    custom_arg_names: tp.Optional[tp.Set[str]] = None,
    impacts_caching: bool = True,
    silence_warnings: tp.Optional[bool] = None,
) -> AttrResolverMixinT:
    """Resolve self.

    Creates a copy of this instance if a different `freq` can be found in `cond_kwargs`.

    In that case, `cond_kwargs` is updated in place: each alias in `self_aliases` that is not
    in `custom_arg_names` is pointed at the copy, `freq` is set to the copy's frequency, and
    `use_caching` is set to False if `impacts_caching` is True. A warning is issued
    unless `silence_warnings` is True (defaults to the wrapping settings)."""
    from vectorbtpro._settings import settings

    wrapping_cfg = settings["wrapping"]

    cond_kwargs = {} if cond_kwargs is None else cond_kwargs
    custom_arg_names = set() if custom_arg_names is None else custom_arg_names
    if silence_warnings is None:
        silence_warnings = wrapping_cfg["silence_warnings"]

    if "freq" not in cond_kwargs:
        return self
    new_wrapper = self.wrapper.replace(freq=cond_kwargs["freq"])
    if not (new_wrapper.freq != self.wrapper.freq):
        return self

    if not silence_warnings:
        message = (
            f"Changing the frequency will create a copy of this instance. "
            f"Consider setting it upon instantiation to re-use existing cache."
        )
        warn(message)
    resolved = self.replace(wrapper=new_wrapper)
    for alias in self.self_aliases:
        if alias in custom_arg_names:
            continue
        cond_kwargs[alias] = resolved
    cond_kwargs["freq"] = resolved.wrapper.freq
    if impacts_caching:
        cond_kwargs["use_caching"] = False
    return resolved
# =================================================================================== """Modules with custom data classes.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.data.custom.alpaca import * from vectorbtpro.data.custom.av import * from vectorbtpro.data.custom.bento import * from vectorbtpro.data.custom.binance import * from vectorbtpro.data.custom.ccxt import * from vectorbtpro.data.custom.csv import * from vectorbtpro.data.custom.custom import * from vectorbtpro.data.custom.db import * from vectorbtpro.data.custom.duckdb import * from vectorbtpro.data.custom.feather import * from vectorbtpro.data.custom.file import * from vectorbtpro.data.custom.finpy import * from vectorbtpro.data.custom.gbm import * from vectorbtpro.data.custom.gbm_ohlc import * from vectorbtpro.data.custom.hdf import * from vectorbtpro.data.custom.local import * from vectorbtpro.data.custom.ndl import * from vectorbtpro.data.custom.parquet import * from vectorbtpro.data.custom.polygon import * from vectorbtpro.data.custom.random import * from vectorbtpro.data.custom.random_ohlc import * from vectorbtpro.data.custom.remote import * from vectorbtpro.data.custom.sql import * from vectorbtpro.data.custom.synthetic import * from vectorbtpro.data.custom.tv import * from vectorbtpro.data.custom.yf import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# ===================================================================================

"""Module with `AlpacaData`."""

import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from alpaca.common.rest import RESTClient as AlpacaClientT
except ImportError:
    AlpacaClientT = "AlpacaClient"

__all__ = [
    "AlpacaData",
]

AlpacaDataT = tp.TypeVar("AlpacaDataT", bound="AlpacaData")


class AlpacaData(RemoteData):
    """Data class for fetching from Alpaca.

    See https://github.com/alpacahq/alpaca-py for API.

    See `AlpacaData.fetch_symbol` for arguments.

    Usage:
        * Set up the API key globally (optional for crypto):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.AlpacaData.set_custom_settings(
        ...     client_config=dict(
        ...         api_key="YOUR_KEY",
        ...         secret_key="YOUR_SECRET"
        ...     )
        ... )
        ```

        * Pull stock data:

        ```pycon
        >>> data = vbt.AlpacaData.pull(
        ...     "AAPL",
        ...     start="2021-01-01",
        ...     end="2022-01-01",
        ...     timeframe="1 day"
        ... )
        ```

        * Pull crypto data:

        ```pycon
        >>> data = vbt.AlpacaData.pull(
        ...     "BTC/USD",
        ...     start="2021-01-01",
        ...     end="2022-01-01",
        ...     timeframe="1 day"
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.alpaca")

    @classmethod
    def list_symbols(
        cls,
        pattern: tp.Optional[str] = None,
        use_regex: bool = False,
        sort: bool = True,
        status: tp.Optional[str] = None,
        asset_class: tp.Optional[str] = None,
        exchange: tp.Optional[str] = None,
        trading_client: tp.Optional[AlpacaClientT] = None,
        client_config: tp.KwargsLike = None,
    ) -> tp.List[str]:
        """List all symbols.

        Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.

        Arguments `status`, `asset_class`, and `exchange` can be strings, such as `asset_class="crypto"`.
        For possible values, take a look into `alpaca.trading.enums`.

        Duplicate symbols are removed; the result is sorted if `sort` is True, otherwise
        the order returned by the API is kept.

        Raises:
            ValueError: If both an initialized `trading_client` and a non-empty `client_config` are passed.

        !!! note
            If you get an authorization error, make sure that you either enable or disable
            the `paper` flag in `client_config` depending upon the account whose credentials you used.
            By default, the credentials are assumed to be of a live trading account (`paper=False`)."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("alpaca")
        from alpaca.trading.client import TradingClient
        from alpaca.trading.requests import GetAssetsRequest
        from alpaca.trading.enums import AssetStatus, AssetClass, AssetExchange

        if client_config is None:
            client_config = {}
        # Remember whether the caller passed a config before global settings get merged in
        has_client_config = len(client_config) > 0
        client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
        if trading_client is None:
            # Keep only those keys that the trading client's constructor accepts
            arg_names = get_func_arg_names(TradingClient.__init__)
            client_config = {k: v for k, v in client_config.items() if k in arg_names}
            trading_client = TradingClient(**client_config)
        elif has_client_config:
            raise ValueError("Cannot apply client_config to already initialized client")

        # Translate string filters into the corresponding enum members
        if isinstance(status, str):
            status = getattr(AssetStatus, status.upper())
        if isinstance(asset_class, str):
            asset_class = getattr(AssetClass, asset_class.upper())
        if isinstance(exchange, str):
            exchange = getattr(AssetExchange, exchange.upper())
        search_params = GetAssetsRequest(status=status, asset_class=asset_class, exchange=exchange)
        assets = trading_client.get_all_assets(search_params)

        all_symbols = []
        for asset in assets:
            symbol = asset.symbol
            if pattern is not None and not cls.key_match(symbol, pattern, use_regex=use_regex):
                continue
            all_symbols.append(symbol)

        # dict.fromkeys() removes duplicates while preserving the order
        unique_symbols = dict.fromkeys(all_symbols)
        if sort:
            return sorted(unique_symbols)
        return list(unique_symbols)

    @classmethod
    def resolve_client(
        cls,
        client: tp.Optional[AlpacaClientT] = None,
        client_type: tp.Optional[str] = None,
        **client_config,
    ) -> AlpacaClientT:
        """Resolve the client.

        If provided, must be of the type `alpaca.data.historical.CryptoHistoricalDataClient`
        for `client_type="crypto"` and `alpaca.data.historical.StockHistoricalDataClient` for
        `client_type="stocks"`. Otherwise, will be created using `client_config`.

        Raises:
            ValueError: If `client_type` is neither "crypto" nor "stocks" when a client must be created,
                or if both an initialized client and a non-empty `client_config` are passed."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("alpaca")
        from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient

        client = cls.resolve_custom_setting(client, "client")
        client_type = cls.resolve_custom_setting(client_type, "client_type")
        # `client_config` is collected from **kwargs and thus is always a dict
        has_client_config = len(client_config) > 0
        client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
        if client is None:
            if client_type == "crypto":
                client_cls = CryptoHistoricalDataClient
            elif client_type == "stocks":
                client_cls = StockHistoricalDataClient
            else:
                raise ValueError(f"Invalid client type: '{client_type}'")
            # Keep only those keys that the client's constructor accepts
            arg_names = get_func_arg_names(client_cls.__init__)
            client_config = {k: v for k, v in client_config.items() if k in arg_names}
            client = client_cls(**client_config)
        elif has_client_config:
            raise ValueError("Cannot apply client_config to already initialized client")
        return client

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        client: tp.Optional[AlpacaClientT] = None,
        client_type: tp.Optional[str] = None,
        client_config: tp.KwargsLike = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        adjustment: tp.Optional[str] = None,
        feed: tp.Optional[str] = None,
        limit: tp.Optional[int] = None,
    ) -> tp.SymbolData:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Alpaca.

        Args:
            symbol (str): Symbol.
            client (alpaca.common.rest.RESTClient): Client.

                See `AlpacaData.resolve_client`.
            client_type (str): Client type.

                See `AlpacaData.resolve_client`.

                Determined automatically based on the symbol. Crypto symbols contain "/".
            client_config (dict): Client config.

                See `AlpacaData.resolve_client`.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            adjustment (str): Specifies the corporate action adjustment for the returned bars.

                Options are: "raw", "split", "dividend" or "all". Default is "raw".
            feed (str): The feed to pull market data from.

                This is either "iex", "otc", or "sip". Feeds "sip" and "otc" are only available to
                those with a subscription. Default is "iex" for free plans and "sip" for paid.
            limit (int): The maximum number of returned items.

        Returns:
            A tuple of the bars as a DataFrame (rows outside of `[start, end)` are dropped)
            and a dict with the keys `tz` and `freq`.

        Raises:
            ValueError: If the timeframe cannot be parsed or uses an unsupported unit.
            TypeError: If the resolved client is neither a crypto nor a stock historical data client.

        For defaults, see `custom.alpaca` in `vectorbtpro._settings.data`.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("alpaca")
        from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient
        from alpaca.data.requests import CryptoBarsRequest, StockBarsRequest
        from alpaca.data.timeframe import TimeFrame, TimeFrameUnit

        if client_type is None:
            client_type = "crypto" if "/" in symbol else "stocks"
        if client_config is None:
            client_config = {}
        client = cls.resolve_client(client=client, client_type=client_type, **client_config)

        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
        feed = cls.resolve_custom_setting(feed, "feed")
        limit = cls.resolve_custom_setting(limit, "limit")

        # Keep the user-provided timeframe as the returned frequency
        freq = timeframe
        split = dt.split_freq_str(timeframe)
        if split is None:
            raise ValueError(f"Invalid timeframe: '{timeframe}'")
        multiplier, unit = split
        unit_map = {
            "m": TimeFrameUnit.Minute,
            "h": TimeFrameUnit.Hour,
            "D": TimeFrameUnit.Day,
            "W": TimeFrameUnit.Week,
            "M": TimeFrameUnit.Month,
        }
        if unit not in unit_map:
            raise ValueError(f"Invalid timeframe: '{timeframe}'")
        timeframe = TimeFrame(multiplier, unit_map[unit])

        # The requests receive naive ISO strings that represent UTC time
        if start is not None:
            start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
            start_str = start.replace(tzinfo=None).isoformat("T")
        else:
            start_str = None
        if end is not None:
            end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
            end_str = end.replace(tzinfo=None).isoformat("T")
        else:
            end_str = None

        if isinstance(client, CryptoHistoricalDataClient):
            request = CryptoBarsRequest(
                symbol_or_symbols=symbol,
                timeframe=timeframe,
                start=start_str,
                end=end_str,
                limit=limit,
            )
            df = client.get_crypto_bars(request).df
        elif isinstance(client, StockHistoricalDataClient):
            request = StockBarsRequest(
                symbol_or_symbols=symbol,
                timeframe=timeframe,
                start=start_str,
                end=end_str,
                limit=limit,
                adjustment=adjustment,
                feed=feed,
            )
            df = client.get_stock_bars(request).df
        else:
            raise TypeError(f"Invalid client of type {type(client)}")

        # Only a multi-level index has a "symbol" level to drop.
        # NOTE(review): presumably a response without bars comes back with a flat index — confirm against alpaca-py
        if isinstance(df.index, pd.MultiIndex):
            df = df.droplevel("symbol", axis=0)
        df.index = df.index.rename("Open time")
        df.rename(
            columns={
                "open": "Open",
                "high": "High",
                "low": "Low",
                "close": "Close",
                "volume": "Volume",
                "trade_count": "Trade count",
                "vwap": "VWAP",
            },
            inplace=True,
        )
        if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
            df = df.tz_localize("utc")

        for float_column in ("Open", "High", "Low", "Close", "Volume", "VWAP"):
            if float_column in df.columns:
                df[float_column] = df[float_column].astype(float)
        if "Trade count" in df.columns:
            df["Trade count"] = df["Trade count"].astype(int, errors="ignore")

        # Drop rows that lie outside of the requested half-open range [start, end)
        if not df.empty and start is not None:
            start = dt.to_timestamp(start, tz=df.index.tz)
            if df.index[0] < start:
                df = df[df.index >= start]
        # Emptiness is checked again since the start filter may have removed all rows
        if not df.empty and end is not None:
            end = dt.to_timestamp(end, tz=df.index.tz)
            if df.index[-1] >= end:
                df = df[df.index < end]

        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Fetch new data for a symbol using the keyword arguments it was originally fetched with.

        The start is set to the last index of the symbol; `kwargs` take precedence
        over the stored keyword arguments."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `AVData`."""

import re
import urllib.parse
from functools import lru_cache

import numpy as np
import pandas as pd
import requests

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.module_ import check_installed
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.warnings_ import warn

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from alpha_vantage.alphavantage import AlphaVantage as AlphaVantageT
except ImportError:
    AlphaVantageT = "AlphaVantage"

__all__ = [
    "AVData",
]

__pdoc__ = {}

AVDataT = tp.TypeVar("AVDataT", bound="AVData")


class AVData(RemoteData):
    """Data class for fetching from Alpha Vantage.

    See https://www.alphavantage.co/documentation/ for API.

    Apart from using https://github.com/RomelTorres/alpha_vantage package, this class can also
    parse the API documentation with `AVData.parse_api_meta` using `BeautifulSoup4` and build
    the API query based on this metadata (pass `use_parser=True`).

    This approach is the most flexible we can get since we can instantly react to Alpha Vantage's changes
    in the API. If the data provider changes its API documentation, you can always adapt the parsing
    procedure by overriding `AVData.parse_api_meta`.

    If parser still fails, you can disable parsing entirely and specify all information manually
    by setting `function` and disabling `match_params`.

    See `AVData.fetch_symbol` for arguments.

    Usage:
        * Set up the API key globally (optional):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.AVData.set_custom_settings(
        ...     apikey="YOUR_KEY"
        ... )
        ```

        * Pull data:

        ```pycon
        >>> data = vbt.AVData.pull(
        ...     "GOOGL",
        ...     timeframe="1 day",
        ... )

        >>> data = vbt.AVData.pull(
        ...     "BTC_USD",
        ...     timeframe="30 minutes",  # premium?
        ...     category="digital-currency",
        ...     outputsize="full"
        ... )

        >>> data = vbt.AVData.pull(
        ...     "REAL_GDP",
        ...     category="economic-indicators"
        ... )

        >>> data = vbt.AVData.pull(
        ...     "IBM",
        ...     category="technical-indicators",
        ...     function="STOCHRSI",
        ...     params=dict(fastkperiod=14)
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.av")

    @classmethod
    def list_symbols(cls, keywords: str, apikey: tp.Optional[str] = None, sort: bool = True) -> tp.List[str]:
        """List all symbols.

        Queries the `SYMBOL_SEARCH` endpoint with `keywords` and returns the unique values
        of the column "symbol", sorted if `sort` is True."""
        apikey = cls.resolve_custom_setting(apikey, "apikey")

        query = dict()
        query["function"] = "SYMBOL_SEARCH"
        query["keywords"] = keywords
        query["datatype"] = "csv"
        query["apikey"] = apikey
        url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(query)
        df = pd.read_csv(url)
        # dict.fromkeys() removes duplicates while preserving the order
        unique_symbols = dict.fromkeys(df["symbol"].tolist())
        if sort:
            return sorted(unique_symbols)
        return list(unique_symbols)

    @classmethod
    @lru_cache()
    def parse_api_meta(cls) -> dict:
        """Parse API metadata from the documentation at https://www.alphavantage.co/documentation

        Returns a dict that maps the id of each documentation section (category) to a dict that
        maps each function name to a dict with the sets `req_args` and `opt_args`.

        Cached class method. To avoid re-parsing the same metadata in different runtimes, save it manually."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("bs4")

        from bs4 import BeautifulSoup

        page = requests.get("https://www.alphavantage.co/documentation")
        # Fail here rather than parse (and cache) the body of an error response
        page.raise_for_status()
        soup = BeautifulSoup(page.content, "html.parser")
        api_meta = {}
        for section in soup.select("article section"):
            category = {}
            function = None
            function_args = dict(req_args=set(), opt_args=set())
            for tag in section.find_all(True):
                if tag.name == "h6":
                    # The heading "API Parameters" of the next endpoint closes the current function
                    if function is not None and tag.select("b")[0].getText().strip() == "API Parameters":
                        category[function] = function_args
                        function = None
                        function_args = dict(req_args=set(), opt_args=set())
                if tag.name == "b":
                    b_text = tag.getText().strip()
                    if b_text.startswith("❚ Required"):
                        arg = tag.select("code")[0].getText().strip()
                        function_args["req_args"].add(arg)
                if tag.name == "p":
                    p_text = tag.getText().strip()
                    if p_text.startswith("❚ Optional"):
                        arg = tag.select("code")[0].getText().strip()
                        function_args["opt_args"].add(arg)
                if tag.name == "code":
                    code_text = tag.getText().strip()
                    if code_text.startswith("function="):
                        function = code_text.replace("function=", "")
            # Store the last function of the section
            if function is not None:
                category[function] = function_args
            api_meta[section.select("h2")[0]["id"]] = category

        return api_meta

    @staticmethod
    def _split_symbol_pair(symbol: str) -> tuple:
        """Split a symbol such as "BTC_USD" into its first two underscore-separated parts."""
        parts = symbol.split("_")
        if len(parts) < 2:
            raise ValueError(f"Symbol '{symbol}' must combine two parts using an underscore")
        return parts[0], parts[1]

    @classmethod
    def _match_common_params(
        cls,
        args: tp.Iterable[str],
        symbol: str,
        interval: tp.Optional[str],
        adjusted: tp.Optional[bool],
        extended: tp.Optional[bool],
        slice_: tp.Optional[str],
        series_type: tp.Optional[str],
        time_period: tp.Optional[int],
        outputsize: tp.Optional[str],
    ) -> dict:
        """Build a dict with those of the common arguments whose names are contained in `args`."""
        matched_params = dict()
        if "symbol" in args and "market" in args:
            matched_params["symbol"], matched_params["market"] = cls._split_symbol_pair(symbol)
        elif "from_symbol" in args and "to_symbol" in args:
            # Naming used by the FX time series endpoints
            matched_params["from_symbol"], matched_params["to_symbol"] = cls._split_symbol_pair(symbol)
        elif "from_currency" in args and "to_currency" in args:
            matched_params["from_currency"], matched_params["to_currency"] = cls._split_symbol_pair(symbol)
        elif "symbol" in args:
            matched_params["symbol"] = symbol
        candidates = (
            ("interval", interval),
            ("adjusted", adjusted),
            ("extended", extended),
            ("extended_hours", extended),
            ("slice", slice_),
            ("series_type", series_type),
            ("time_period", time_period),
            ("outputsize", outputsize),
        )
        for name, value in candidates:
            if name in args:
                matched_params[name] = value
        return matched_params

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        use_parser: tp.Optional[bool] = None,
        apikey: tp.Optional[str] = None,
        api_meta: tp.Optional[dict] = None,
        category: tp.Union[None, str, AlphaVantageT, tp.Type[AlphaVantageT]] = None,
        function: tp.Union[None, str, tp.Callable] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        adjusted: tp.Optional[bool] = None,
        extended: tp.Optional[bool] = None,
        slice: tp.Optional[str] = None,
        series_type: tp.Optional[str] = None,
        time_period: tp.Optional[int] = None,
        outputsize: tp.Optional[str] = None,
        match_params: tp.Optional[bool] = None,
        params: tp.KwargsLike = None,
        read_csv_kwargs: tp.KwargsLike = None,
        silence_warnings: tp.Optional[bool] = None,
    ) -> tp.SymbolData:
        """Fetch a symbol from Alpha Vantage.

        If `use_parser` is False, or None and `alpha_vantage` is installed, uses the package.
        Otherwise, parses the API documentation and pulls data directly.

        See https://www.alphavantage.co/documentation/ for API endpoints and their parameters.

        !!! note
            Supports the CSV format only.

        Args:
            symbol (str): Symbol.

                May combine symbol/from_currency and market/to_currency using an underscore.
            use_parser (bool): Whether to use the parser instead of the `alpha_vantage` package.
            apikey (str): API key.
            api_meta (dict): API meta.

                If None, will use `AVData.parse_api_meta` if `function` is not provided
                or `match_params` is True.
            category (str or AlphaVantage): API category of your choice.

                Used if `function` is not provided or `match_params` is True.

                Supported are:

                * `alpha_vantage.alphavantage.AlphaVantage` instance, class, or class name
                * "time-series-data" or "time-series"
                * "fundamental-data" or "fundamentals"
                * "foreign-exchange", "forex", or "fx"
                * "digital-currency", "cryptocurrencies", "cryptocurrency", or "crypto"
                * "commodities"
                * "economic-indicators"
                * "technical-indicators" or "indicators"
            function (str or callable): API function of your choice.

                If None, will try to resolve it based on other arguments, such as `timeframe`,
                `adjusted`, and `extended`. Required for technical indicators, economic indicators,
                and fundamental data.

                See the keys in sub-dictionaries returned by `AVData.parse_api_meta`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".

                For time series, forex, and crypto, looks for interval type in the function's name.
                Defaults to "60min" if extended, otherwise to "daily".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            adjusted (bool): Whether to return time series adjusted by historical split and dividend events.
            extended (bool): Whether to return historical intraday time series for the trailing 2 years.
            slice (str): Slice of the trailing 2 years.
            series_type (str): The desired price type in the time series.
            time_period (int): Number of data points used to calculate each window value.
            outputsize (str): Output size.

                Supported are

                * "compact" that returns only the latest 100 data points
                * "full" that returns the full-length time series
            match_params (bool): Whether to match parameters with the ones required by the endpoint.

                Otherwise, uses only (resolved) `function`, `apikey`, `datatype="csv"`, and `params`.
            params: Additional keyword arguments passed as key/value pairs in the URL.

                If `match_params` is True, each key must be expected by the function.
            read_csv_kwargs (dict): Keyword arguments passed to `pd.read_csv`.
            silence_warnings (bool): Whether to silence all warnings.

        Returns:
            A tuple of the data as a DataFrame in ascending index order and a dict
            with the keys `tz` and `freq`.

        Raises:
            ValueError: If the API documentation cannot be parsed, the timeframe or category is invalid,
                no single function satisfies the requirements, or parameters cannot be matched.
            TypeError: If the category has an unsupported type.
            NotImplementedError: If the category is not supported by the `alpha_vantage` package.

        For defaults, see `custom.av` in `vectorbtpro._settings.data`.
        """
        use_parser = cls.resolve_custom_setting(use_parser, "use_parser")
        apikey = cls.resolve_custom_setting(apikey, "apikey")
        api_meta = cls.resolve_custom_setting(api_meta, "api_meta")
        category = cls.resolve_custom_setting(category, "category")
        function = cls.resolve_custom_setting(function, "function")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        adjusted = cls.resolve_custom_setting(adjusted, "adjusted")
        extended = cls.resolve_custom_setting(extended, "extended")
        slice = cls.resolve_custom_setting(slice, "slice")
        series_type = cls.resolve_custom_setting(series_type, "series_type")
        time_period = cls.resolve_custom_setting(time_period, "time_period")
        outputsize = cls.resolve_custom_setting(outputsize, "outputsize")
        read_csv_kwargs = cls.resolve_custom_setting(read_csv_kwargs, "read_csv_kwargs", merge=True)
        match_params = cls.resolve_custom_setting(match_params, "match_params")
        params = cls.resolve_custom_setting(params, "params", merge=True)
        silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")

        if use_parser is None:
            # Prefer the package unless API metadata was provided or the package is missing
            if api_meta is None and check_installed("alpha_vantage"):
                use_parser = False
            else:
                use_parser = True
        if not use_parser:
            from vectorbtpro.utils.module_ import assert_can_import

            assert_can_import("alpha_vantage")

        if use_parser and api_meta is None and (function is None or match_params):
            # Warn only if the documentation hasn't been parsed (and cached) yet
            if not silence_warnings and cls.parse_api_meta.cache_info().misses == 0:
                warn("Parsing API documentation...")
            try:
                api_meta = cls.parse_api_meta()
            except Exception as e:
                raise ValueError(
                    "Can't fetch/parse the API documentation. Specify function and disable match_params."
                ) from e

        # Keep the user-provided timeframe as the returned frequency
        freq = timeframe
        interval = None
        interval_type = None
        if timeframe is not None:
            if not isinstance(timeframe, str):
                raise ValueError(f"Invalid timeframe: '{timeframe}'")
            split = dt.split_freq_str(timeframe)
            if split is None:
                raise ValueError(f"Invalid timeframe: '{timeframe}'")
            multiplier, unit = split
            named_intervals = {
                "D": "daily",
                "W": "weekly",
                "M": "monthly",
                "Q": "quarterly",
                "Y": "annual",
            }
            if unit == "m":
                interval = str(multiplier) + "min"
                interval_type = "intraday"
            elif unit == "h":
                interval = str(60 * multiplier) + "min"
                interval_type = "intraday"
            elif unit in named_intervals:
                interval = named_intervals[unit]
                interval_type = named_intervals[unit]
            # NOTE(review): this fires only for unrecognized units; a timeframe such as "2 days"
            # passes and is treated as "daily" — confirm whether it should be rejected as well
            if interval is None and multiplier > 1:
                raise ValueError("Multipliers are supported only for intraday timeframes")
        else:
            if extended:
                interval_type = "intraday"
                interval = "60min"
            else:
                interval_type = "daily"
                interval = "daily"

        if category is not None:
            if isinstance(category, str):
                # Map the alias either to the documentation section id or to the package class
                if category.lower() in ("time-series-data", "time-series", "timeseries"):
                    if use_parser:
                        category = "time-series-data"
                    else:
                        from alpha_vantage.timeseries import TimeSeries

                        category = TimeSeries
                elif category.lower() in ("fundamentals", "fundamental-data", "fundamentaldata"):
                    if use_parser:
                        category = "fundamentals"
                    else:
                        from alpha_vantage.fundamentaldata import FundamentalData

                        category = FundamentalData
                elif category.lower() in ("fx", "forex", "foreign-exchange", "foreignexchange"):
                    if use_parser:
                        category = "fx"
                    else:
                        from alpha_vantage.foreignexchange import ForeignExchange

                        category = ForeignExchange
                elif category.lower() in ("digital-currency", "cryptocurrencies", "cryptocurrency", "crypto"):
                    if use_parser:
                        category = "digital-currency"
                    else:
                        from alpha_vantage.cryptocurrencies import CryptoCurrencies

                        category = CryptoCurrencies
                elif category.lower() in ("commodities",):
                    if use_parser:
                        category = "commodities"
                    else:
                        raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
                elif category.lower() in ("economic-indicators",):
                    if use_parser:
                        category = "economic-indicators"
                    else:
                        raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
                elif category.lower() in ("technical-indicators", "techindicators", "indicators"):
                    if use_parser:
                        category = "technical-indicators"
                    else:
                        from alpha_vantage.techindicators import TechIndicators

                        category = TechIndicators
                else:
                    raise ValueError(f"Invalid category: '{category}'")
            else:
                if use_parser:
                    raise TypeError("Category must be a string")
                else:
                    from alpha_vantage.alphavantage import AlphaVantage

                    if isinstance(category, type):
                        if not issubclass(category, AlphaVantage):
                            raise TypeError("Category must be a subclass of AlphaVantage")
                    elif not isinstance(category, AlphaVantage):
                        raise TypeError("Category must be an instance of AlphaVantage")

        if use_parser:
            if function is None and category in ("commodities", "economic-indicators"):
                # For these categories, the symbol itself is the function
                function = symbol
            if function is None:
                if category is None:
                    category = "time-series-data"
                if category in ("fundamentals", "technical-indicators"):
                    raise ValueError("Function is required")
                # Find the only function whose name fits the interval type, `adjusted`, and `extended`
                adjusted_in_functions = False
                extended_in_functions = False
                matched_functions = []
                for k in api_meta[category]:
                    if interval_type is None or interval_type.upper() in k:
                        if "ADJUSTED" in k:
                            adjusted_in_functions = True
                        if "EXTENDED" in k:
                            extended_in_functions = True
                        matched_functions.append(k)

                if adjusted_in_functions:
                    matched_functions = [
                        k
                        for k in matched_functions
                        if (adjusted and "ADJUSTED" in k) or (not adjusted and "ADJUSTED" not in k)
                    ]
                if extended_in_functions:
                    matched_functions = [
                        k
                        for k in matched_functions
                        if (extended and "EXTENDED" in k) or (not extended and "EXTENDED" not in k)
                    ]
                if len(matched_functions) == 0:
                    raise ValueError("No functions satisfy the requirements")
                if len(matched_functions) > 1:
                    raise ValueError("More than one function satisfies the requirements")
                function = matched_functions[0]

            if match_params:
                if function is not None and category is None:
                    # Take the first category that lists the function
                    for k, v in api_meta.items():
                        if function in v:
                            category = k
                            break
                if category is None:
                    raise ValueError("Category is required")
                req_args = api_meta[category][function]["req_args"]
                opt_args = api_meta[category][function]["opt_args"]
                args = set(req_args) | set(opt_args)

                matched_params = dict()
                matched_params["function"] = function
                matched_params["datatype"] = "csv"
                matched_params["apikey"] = apikey
                matched_params.update(
                    cls._match_common_params(
                        args,
                        symbol,
                        interval,
                        adjusted,
                        extended,
                        slice,
                        series_type,
                        time_period,
                        outputsize,
                    )
                )
                for k, v in params.items():
                    if k in args:
                        matched_params[k] = v
                    else:
                        raise ValueError(f"Function '{function}' does not expect parameter '{k}'")
                for arg in req_args:
                    if arg not in matched_params:
                        raise ValueError(f"Function '{function}' requires parameter '{arg}'")
            else:
                matched_params = dict(params)
                matched_params["function"] = function
                matched_params["apikey"] = apikey
                matched_params["datatype"] = "csv"
            url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(matched_params)
            df = pd.read_csv(url, **read_csv_kwargs)
        else:
            from alpha_vantage.alphavantage import AlphaVantage
            from alpha_vantage.timeseries import TimeSeries
            from alpha_vantage.fundamentaldata import FundamentalData
            from alpha_vantage.foreignexchange import ForeignExchange
            from alpha_vantage.cryptocurrencies import CryptoCurrencies
            from alpha_vantage.techindicators import TechIndicators

            if isinstance(category, type) and issubclass(category, AlphaVantage):
                category = category(key=apikey, output_format="pandas")
            if function is None:
                if category is None:
                    category = TimeSeries(key=apikey, output_format="pandas")
                if isinstance(category, (TechIndicators, FundamentalData)):
                    raise ValueError("Function is required")
                # Find the only method whose name fits the interval type, `adjusted`, and `extended`
                adjusted_in_methods = False
                extended_in_methods = False
                matched_methods = []
                for k in dir(category):
                    if interval_type is None or interval_type in k:
                        if "adjusted" in k:
                            adjusted_in_methods = True
                        if "extended" in k:
                            extended_in_methods = True
                        matched_methods.append(k)

                if adjusted_in_methods:
                    matched_methods = [
                        k
                        for k in matched_methods
                        if (adjusted and "adjusted" in k) or (not adjusted and "adjusted" not in k)
                    ]
                if extended_in_methods:
                    matched_methods = [
                        k
                        for k in matched_methods
                        if (extended and "extended" in k) or (not extended and "extended" not in k)
                    ]
                if len(matched_methods) == 0:
                    raise ValueError("No methods satisfy the requirements")
                if len(matched_methods) > 1:
                    raise ValueError("More than one method satisfies the requirements")
                function = matched_methods[0]

            if isinstance(function, str):
                function = function.lower()
                if not function.startswith("get_"):
                    function = "get_" + function
                if category is not None:
                    function = getattr(category, function)
                else:
                    # Search all package classes for a method with this name
                    matched_classes = [
                        category_cls
                        for category_cls in (
                            TimeSeries,
                            FundamentalData,
                            ForeignExchange,
                            CryptoCurrencies,
                            TechIndicators,
                        )
                        if function in dir(category_cls)
                    ]
                    if len(matched_classes) == 0:
                        raise ValueError("No methods satisfy the requirements")
                    if len(matched_classes) > 1:
                        raise ValueError("More than one method satisfies the requirements")
                    # The method must be bound to an instance, otherwise the call below lacks `self`
                    category = matched_classes[0](key=apikey, output_format="pandas")
                    function = getattr(category, function)
            if match_params:
                args = set(get_func_arg_names(function))
                matched_params = cls._match_common_params(
                    args,
                    symbol,
                    interval,
                    adjusted,
                    extended,
                    slice,
                    series_type,
                    time_period,
                    outputsize,
                )
                # Apply user-defined parameters the same way as when using the parser
                for k, v in params.items():
                    if k in args:
                        matched_params[k] = v
                    else:
                        function_name = getattr(function, "__name__", function)
                        raise ValueError(f"Function '{function_name}' does not expect parameter '{k}'")
            else:
                matched_params = dict(params)
            df, df_metadata = function(**matched_params)
            # Take the timezone from the metadata if none was provided
            for k, v in df_metadata.items():
                if "Time Zone" in k:
                    if tz is None:
                        if v.endswith(" Time"):
                            v = v[: -len(" Time")]
                        tz = v

        df.index.name = None
        new_columns = []
        for c in df.columns:
            # Remove numbered prefixes such as "1. " and capitalize the first letter
            new_c = re.sub(r"^\d+\w*\.\s*", "", c)
            new_c = new_c[:1].title() + new_c[1:]
            if new_c.endswith(" (USD)"):
                new_c = new_c[: -len(" (USD)")]
            new_columns.append(new_c)
        df = df.rename(columns=dict(zip(df.columns, new_columns)))
        df = df.loc[:, ~df.columns.duplicated()]
        for c in df.columns:
            if df[c].dtype == "O":
                # A lone dot is replaced by NaN
                df[c] = df[c].replace({".": np.nan})
        df = df.apply(pd.to_numeric, errors="ignore")
        # Data that arrives in descending order is reversed; at least two rows are needed to tell
        if len(df.index) > 1 and df.index[0] > df.index[1]:
            df = df.iloc[::-1]
        if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None and tz is not None:
            df = df.tz_localize(tz)

        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Fetch a symbol again using the keyword arguments it was originally fetched with.

        `kwargs` take precedence over the stored keyword arguments."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `BentoData`."""

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from databento import Historical as HistoricalT
except ImportError:
    HistoricalT = "Historical"

__all__ = [
    "BentoData",
]


class BentoData(RemoteData):
    """Data class for fetching from Databento.

    See https://github.com/databento/databento-python for API.

    See `BentoData.fetch_symbol` for arguments.

    Usage:
        * Set up the API key globally (optional):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.BentoData.set_custom_settings(
        ...     client_config=dict(
        ...         key="YOUR_KEY"
        ...     )
        ... )
        ```

        * Pull data:

        ```pycon
        >>> data = vbt.BentoData.pull(
        ...     "AAPL",
        ...     dataset="XNAS.ITCH"
        ... )
        ```

        ```pycon
        >>> data = vbt.BentoData.pull(
        ...     "AAPL",
        ...     dataset="XNAS.ITCH",
        ...     timeframe="hourly",
        ...     start="one week ago"
        ... )
        ```

        ```pycon
        >>> data = vbt.BentoData.pull(
        ...     "ES.FUT",
        ...     dataset="GLBX.MDP3",
        ...     stype_in="parent",
        ...     schema="mbo",
        ...     start="2022-06-10T14:30",
        ...     end="2022-06-11",
        ...     limit=1000
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.bento")

    @classmethod
    def resolve_client(cls, client: tp.Optional[HistoricalT] = None, **client_config) -> HistoricalT:
        """Resolve the client.

        If provided, must be of the type `databento.historical.client.Historical`.
        Otherwise, will be created using `client_config`.

        Raises:
            ValueError: If both an initialized client and a non-empty `client_config` are passed."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("databento")
        from databento import Historical

        client = cls.resolve_custom_setting(client, "client")
        # `client_config` is collected from **kwargs and thus is always a dict
        has_client_config = len(client_config) > 0
        client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
        if client is None:
            client = Historical(**client_config)
        elif has_client_config:
            raise ValueError("Cannot apply client_config to already initialized client")
        return client

    @classmethod
    def get_cost(cls, symbols: tp.MaybeSymbols, **kwargs) -> float:
        """Get the cost of calling `BentoData.fetch_symbol` on one or more symbols.

        Resolves the request parameters of each symbol with `BentoData.fetch_symbol`
        (`return_params=True`), passes those accepted by the client's `metadata.get_cost`
        to it, and returns the sum of the costs."""
        if isinstance(symbols, str):
            symbols = [symbols]
        costs = []
        for symbol in symbols:
            client, params = cls.fetch_symbol(symbol, **kwargs, return_params=True)
            # Keep only those parameters that the cost endpoint accepts
            cost_arg_names = get_func_arg_names(client.metadata.get_cost)
            cost_params = {k: v for k, v in params.items() if k in cost_arg_names}
            costs.append(client.metadata.get_cost(**cost_params, mode="historical"))
        return sum(costs)

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        client: tp.Optional[HistoricalT] = None,
        client_config: tp.KwargsLike = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        resolve_dates: tp.Optional[bool] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        dataset: tp.Optional[str] = None,
        schema: tp.Optional[str] = None,
        return_params: bool = False,
        df_kwargs: tp.KwargsLike = None,
        **params,
    ) -> tp.Union[float, tp.SymbolData]:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Databento.

        Args:
            symbol (str): Symbol.

                Symbol can be in the `DATASET:SYMBOL` format if `dataset` is None.
                Only the first colon separates the dataset from the symbol.
            client (databento.historical.client.Historical): Client.

                See `BentoData.resolve_client`.
            client_config (dict): Client config.

                See `BentoData.resolve_client`.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            resolve_dates (bool): Whether to resolve `start` and `end`, or pass them as they are.

                If True, both are clipped to the range of the dataset.
            timeframe (str): Timeframe to create `schema` from.

                Allows human-readable strings such as "1 minute".

                If `timeframe` is not None and `schema` is neither None nor "ohlcv", will raise an error.
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            dataset (str): See `databento.historical.client.Historical.get_range`.
            schema (str): See `databento.historical.client.Historical.get_range`.
            return_params (bool): Whether to return the client and (final) parameters instead of data.

                Used by `BentoData.get_cost`.
            df_kwargs (dict): Keyword arguments passed to `databento.common.dbnstore.DBNStore.to_df`.
            **params: Keyword arguments passed to `databento.historical.client.Historical.get_range`.

        Returns:
            A tuple of the data as a DataFrame and a dict with the keys `tz` and `freq`.
            If `return_params` is True, a tuple of the client and the parameters dict instead.

        Raises:
            ValueError: If `timeframe` is used together with a schema other than "ohlcv".

        For defaults, see `custom.bento` in `vectorbtpro._settings.data`.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("databento")

        if client_config is None:
            client_config = {}
        client = cls.resolve_client(client=client, **client_config)

        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        resolve_dates = cls.resolve_custom_setting(resolve_dates, "resolve_dates")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        dataset = cls.resolve_custom_setting(dataset, "dataset")
        schema = cls.resolve_custom_setting(schema, "schema")
        params = cls.resolve_custom_setting(params, "params", merge=True)
        df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True)

        if dataset is None and ":" in symbol:
            # Split at the first colon only, so the symbol part may contain colons itself
            dataset, symbol = symbol.split(":", 1)
        if timeframe is None and schema is None:
            schema = "ohlcv-1d"
            freq = "1d"
        elif timeframe is not None:
            freq = timeframe
            split = dt.split_freq_str(timeframe)
            if split is not None:
                multiplier, unit = split
                timeframe = str(multiplier) + unit
            if schema is None or schema.lower() == "ohlcv":
                schema = f"ohlcv-{timeframe}"
            else:
                raise ValueError("Timeframe cannot be used together with schema")
        else:
            # Only the schema is provided: derive the frequency from it if possible
            if schema.startswith("ohlcv-"):
                freq = schema[len("ohlcv-") :]
            else:
                freq = None
        if resolve_dates:
            dataset_range = client.metadata.get_dataset_range(dataset)
            # The range uses either the keys "start_date"/"end_date" or "start"/"end"
            if "start_date" in dataset_range:
                start_date = dt.to_tzaware_timestamp(dataset_range["start_date"], naive_tz="utc", tz="utc")
            else:
                start_date = dt.to_tzaware_timestamp(dataset_range["start"], naive_tz="utc", tz="utc")
            if "end_date" in dataset_range:
                end_date = dt.to_tzaware_timestamp(dataset_range["end_date"], naive_tz="utc", tz="utc")
            else:
                end_date = dt.to_tzaware_timestamp(dataset_range["end"], naive_tz="utc", tz="utc")
            # Clip the requested range to the range of the dataset
            if start is not None:
                start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz="utc")
                if start < start_date:
                    start = start_date
            else:
                start = start_date
            if end is not None:
                end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz="utc")
                if end > end_date:
                    end = end_date
            else:
                end = end_date
            # Timestamps at midnight are passed as dates, others as full ISO strings
            if start.floor("d") == start:
                start = start.date().isoformat()
            else:
                start = start.isoformat()
            if end.floor("d") == end:
                end = end.date().isoformat()
            else:
                end = end.isoformat()

        # User-defined parameters take precedence over the resolved ones
        params = merge_dicts(
            dict(
                dataset=dataset,
                start=start,
                end=end,
                symbols=symbol,
                schema=schema,
            ),
            params,
        )
        if return_params:
            return client, params

        df = client.timeseries.get_range(**params).to_df(**df_kwargs)
        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Fetch new data for a symbol using the keyword arguments it was originally fetched with.

        The start is set to the last index of the symbol; `kwargs` take precedence
        over the stored keyword arguments."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `BinanceData`."""

import time
import traceback
from functools import partial

import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from binance.client import Client as BinanceClientT
except ImportError:
    BinanceClientT = "BinanceClient"

__all__ = [
    "BinanceData",
]

__pdoc__ = {}

BinanceDataT = tp.TypeVar("BinanceDataT", bound="BinanceData")


class BinanceData(RemoteData):
    """Data class for fetching from Binance.

    See https://github.com/sammchardy/python-binance for API.

    See `BinanceData.fetch_symbol` for arguments.

    !!! note
        If you are using an exchange from the US, Japan or other TLD then make sure to pass `tld="us"`
        in `client_config` when creating the client.

    Usage:
        * Set up the API key globally (optional):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.BinanceData.set_custom_settings(
        ...     client_config=dict(
        ...         api_key="YOUR_KEY",
        ...         api_secret="YOUR_SECRET"
        ...     )
        ... )
        ```

        * Pull data:

        ```pycon
        >>> data = vbt.BinanceData.pull(
        ...     "BTCUSDT",
        ...     start="2020-01-01",
        ...     end="2021-01-01",
        ...     timeframe="1 day"
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.binance")

    # All volume-like features are summed when resampling
    _feature_config: tp.ClassVar[Config] = HybridConfig(
        {
            feature_name: dict(
                resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
                    resampler,
                    generic_nb.sum_reduce_nb,
                )
            )
            for feature_name in ("Quote volume", "Taker base volume", "Taker quote volume")
        }
    )

    @property
    def feature_config(self) -> Config:
        return self._feature_config

    @classmethod
    def resolve_client(cls, client: tp.Optional[BinanceClientT] = None, **client_config) -> BinanceClientT:
        """Return a ready-to-use Binance client.

        An explicitly passed (or globally configured) client must be an instance of
        `binance.client.Client` and cannot be combined with `client_config`.
        Without a client, a new one is built from `client_config`."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("binance")
        from binance.client import Client

        client = cls.resolve_custom_setting(client, "client")
        if client_config is None:
            client_config = {}
        # Whether the caller passed any config, checked before the global settings are merged in
        config_was_passed = len(client_config) > 0
        client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
        if client is not None:
            if config_was_passed:
                raise ValueError("Cannot apply client_config to already initialized client")
            return client
        return Client(**client_config)

    @classmethod
    def list_symbols(
        cls,
        pattern: tp.Optional[str] = None,
        use_regex: bool = False,
        sort: bool = True,
        client: tp.Optional[BinanceClientT] = None,
        client_config: tp.KwargsLike = None,
    ) -> tp.List[str]:
        """Return the de-duplicated symbols listed in the exchange info.

        If `pattern` is given, each symbol is checked against it with
        `vectorbtpro.data.custom.custom.CustomData.key_match`."""
        if client_config is None:
            client_config = {}
        client = cls.resolve_client(client=client, **client_config)
        matched_symbols = [
            info["symbol"]
            for info in client.get_exchange_info()["symbols"]
            if pattern is None or cls.key_match(info["symbol"], pattern, use_regex=use_regex)
        ]
        # A dict removes duplicates while keeping the first-seen order
        unique_symbols = dict.fromkeys(matched_symbols)
        return sorted(unique_symbols) if sort else list(unique_symbols)

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        client: tp.Optional[BinanceClientT] = None,
        client_config: tp.KwargsLike = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        klines_type: tp.Union[None, int, str] = None,
        limit: tp.Optional[int] = None,
        delay: tp.Optional[float] = None,
        show_progress: tp.Optional[bool] = None,
        pbar_kwargs: tp.KwargsLike = None,
        silence_warnings: tp.Optional[bool] = None,
        **get_klines_kwargs,
    ) -> tp.SymbolData:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Binance.

        Klines are requested batch by batch until a request yields no new klines.
        If a request fails, a warning is emitted (unless silenced) and the data
        collected so far is returned.

        Args:
            symbol (str): Symbol.
            client (binance.client.Client): Client.

                See `BinanceData.resolve_client`.
            client_config (dict): Client config.

                See `BinanceData.resolve_client`.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            klines_type (int or str): Kline type.

                See `binance.enums.HistoricalKlinesType`. Supports strings.
            limit (int): The maximum number of returned items.
            delay (float): Time to sleep after each request (in seconds).
            show_progress (bool): Whether to show the progress bar.
            pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
            silence_warnings (bool): Whether to silence all warnings.
            **get_klines_kwargs: Keyword arguments passed to `binance.client.Client.get_klines`.

        For defaults, see `custom.binance` in `vectorbtpro._settings.data`.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("binance")
        from binance.enums import HistoricalKlinesType

        if client_config is None:
            client_config = {}
        client = cls.resolve_client(client=client, **client_config)

        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        klines_type = cls.resolve_custom_setting(klines_type, "klines_type")
        if isinstance(klines_type, str):
            klines_type = map_enum_fields(klines_type, HistoricalKlinesType)
        if isinstance(klines_type, int):
            members_by_value = {member.value: member for member in HistoricalKlinesType}
            klines_type = members_by_value[klines_type]
        limit = cls.resolve_custom_setting(limit, "limit")
        delay = cls.resolve_custom_setting(delay, "delay")
        show_progress = cls.resolve_custom_setting(show_progress, "show_progress")
        pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True)
        if "bar_id" not in pbar_kwargs:
            pbar_kwargs["bar_id"] = "binance"
        silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
        get_klines_kwargs = cls.resolve_custom_setting(get_klines_kwargs, "get_klines_kwargs", merge=True)

        # Translate the timeframe into the interval notation used in the request
        freq = timeframe
        freq_parts = dt.split_freq_str(timeframe)
        if freq_parts is not None:
            multiplier, unit = freq_parts
            unit = {"D": "d", "W": "w"}.get(unit, unit)
            timeframe = str(multiplier) + unit

        # Establish the boundaries in milliseconds
        if start is None:
            start_ts = None
        else:
            start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
            first_valid_ts = client._get_earliest_valid_timestamp(symbol, timeframe, klines_type)
            start_ts = max(start_ts, first_valid_ts)
        if end is None:
            end_ts = None
        else:
            end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
        prev_end_ts = None

        def _ts_to_str(ts: tp.Optional[int]) -> str:
            if ts is None:
                return "?"
            return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)

        def _is_new_kline(kline: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
            # Keep klines opened within [start_ts, end_ts) and after the previous batch
            open_ts = kline[0]
            if start_ts is not None and open_ts < start_ts:
                return False
            if _prev_end_ts is not None and open_ts <= _prev_end_ts:
                return False
            if end_ts is not None and open_ts >= end_ts:
                return False
            return True

        # Collect the klines batch by batch
        data = []
        try:
            with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
                pbar.set_description("{} → ?".format(_ts_to_str(start_ts)))
                while True:
                    # Continue from the last collected kline, if any
                    since_ts = start_ts if prev_end_ts is None else prev_end_ts
                    batch = client._klines(
                        symbol=symbol,
                        interval=timeframe,
                        limit=limit,
                        startTime=since_ts,
                        endTime=end_ts,
                        klines_type=klines_type,
                        **get_klines_kwargs,
                    )
                    batch = [kline for kline in batch if _is_new_kline(kline, _prev_end_ts=prev_end_ts)]
                    if len(batch) == 0:
                        break
                    data.extend(batch)

                    # Update the timestamps and the progress bar
                    last_ts = batch[-1][0]
                    if start_ts is None:
                        start_ts = batch[0][0]
                    pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(last_ts)))
                    pbar.update()
                    prev_end_ts = last_ts
                    if end_ts is not None and prev_end_ts >= end_ts:
                        break
                    if delay is not None:
                        time.sleep(delay)  # be kind to api
        except Exception:
            if not silence_warnings:
                warn(traceback.format_exc())
                warn(
                    f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
                    "Use update() method to fetch missing data."
                )

        # Convert the collected klines to a DataFrame
        df = pd.DataFrame(
            data,
            columns=[
                "Open time",
                "Open",
                "High",
                "Low",
                "Close",
                "Volume",
                "Close time",
                "Quote volume",
                "Trade count",
                "Taker base volume",
                "Taker quote volume",
                "Ignore",
            ],
        )
        df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
        for column in ("Open", "High", "Low", "Close", "Volume", "Quote volume"):
            df[column] = df[column].astype(float)
        df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
        for column in ("Taker base volume", "Taker quote volume"):
            df[column] = df[column].astype(float)
        for column in ("Open time", "Close time", "Ignore"):
            del df[column]

        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Re-fetch the symbol from its last index onwards using the stored fetch arguments."""
        stored_kwargs = self.select_fetch_kwargs(symbol)
        stored_kwargs["start"] = self.select_last_index(symbol)
        return self.fetch_symbol(symbol, **merge_dicts(stored_kwargs, kwargs))


BinanceData.override_feature_config_doc(__pdoc__)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `CCXTData`."""

import time
import traceback
from functools import wraps, partial

import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from ccxt.base.exchange import Exchange as CCXTExchangeT
except ImportError:
    CCXTExchangeT = "CCXTExchange"

__all__ = [
    "CCXTData",
]

__pdoc__ = {}


class CCXTData(RemoteData):
    """Data class for fetching using CCXT.

    See https://github.com/ccxt/ccxt for API.

    See `CCXTData.fetch_symbol` for arguments.

    Usage:
        * Set up the API key globally (optional):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.CCXTData.set_exchange_settings(
        ...     exchange_name="binance",
        ...     populate_=True,
        ...     exchange_config=dict(
        ...         apiKey="YOUR_KEY",
        ...         secret="YOUR_SECRET"
        ...     )
        ... )
        ```

        * Pull data:

        ```pycon
        >>> data = vbt.CCXTData.pull(
        ...     "BTCUSDT",
        ...     exchange="binance",
        ...     start="2020-01-01",
        ...     end="2021-01-01",
        ...     timeframe="1 day"
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.ccxt")

    @staticmethod
    def _exchange_sub_path(exchange_name: tp.Optional[str] = None) -> tp.Optional[str]:
        """Return the settings sub-path of an exchange, or None if no exchange name is given."""
        if exchange_name is None:
            return None
        return "exchanges." + exchange_name

    @classmethod
    def get_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> dict:
        """`CCXTData.get_custom_settings` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def has_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
        """`CCXTData.has_custom_settings` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def get_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
        """`CCXTData.get_custom_setting` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def has_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
        """`CCXTData.has_custom_setting` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def resolve_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
        """`CCXTData.resolve_custom_setting` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def set_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> None:
        """`CCXTData.set_custom_settings` with `sub_path=exchange_name`."""
        sub_path = cls._exchange_sub_path(exchange_name)
        cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)

    @classmethod
    def list_symbols(
        cls,
        pattern: tp.Optional[str] = None,
        use_regex: bool = False,
        sort: bool = True,
        exchange: tp.Union[None, str, CCXTExchangeT] = None,
        exchange_config: tp.Optional[tp.KwargsLike] = None,
    ) -> tp.List[str]:
        """List all symbols.

        Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
        if exchange_config is None:
            exchange_config = {}
        exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
        all_symbols = []
        for symbol in exchange.load_markets():
            if pattern is not None:
                if not cls.key_match(symbol, pattern, use_regex=use_regex):
                    continue
            all_symbols.append(symbol)

        # A dict removes duplicates while keeping the first-seen order
        if sort:
            return sorted(dict.fromkeys(all_symbols))
        return list(dict.fromkeys(all_symbols))

    @classmethod
    def resolve_exchange(
        cls,
        exchange: tp.Union[None, str, CCXTExchangeT] = None,
        **exchange_config,
    ) -> CCXTExchangeT:
        """Resolve the exchange.

        If provided, must be of the type `ccxt.base.exchange.Exchange`.
        Otherwise, will be created using `exchange_config`.

        Falls back to "binance" if no exchange is provided or configured.

        Raises:
            ValueError: If the exchange is of an unknown type, is not found in CCXT,
                or `exchange_config` is passed together with an exchange instance.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("ccxt")
        import ccxt

        exchange = cls.resolve_exchange_setting(exchange, "exchange")
        if exchange is None:
            exchange = "binance"
        if isinstance(exchange, str):
            exchange = exchange.lower()
            exchange_name = exchange
        elif isinstance(exchange, ccxt.Exchange):
            exchange_name = type(exchange).__name__
        else:
            raise ValueError(f"Unknown exchange of type {type(exchange)}")
        if exchange_config is None:
            exchange_config = {}
        # Remember whether the caller passed any config before merging in the global settings
        has_exchange_config = len(exchange_config) > 0
        exchange_config = cls.resolve_exchange_setting(
            exchange_config, "exchange_config", merge=True, exchange_name=exchange_name
        )
        if isinstance(exchange, str):
            if not hasattr(ccxt, exchange):
                raise ValueError(f"Exchange '{exchange}' not found in CCXT")
            exchange = getattr(ccxt, exchange)(exchange_config)
        else:
            if has_exchange_config:
                raise ValueError("Cannot apply config after instantiation of the exchange")
        return exchange

    @staticmethod
    def _find_earliest_date(
        fetch_func: tp.Callable,
        start: tp.DatetimeLike = 0,
        end: tp.DatetimeLike = "now",
        tz: tp.TimezoneLike = None,
        for_internal_use: bool = False,
    ) -> tp.Optional[pd.Timestamp]:
        """Find the earliest date using binary search.

        `fetch_func` must take a timestamp in milliseconds and a limit, and return a sequence
        of records with the timestamp in the first position.

        The data is probed at `start` first, then at timestamp 0, and only then searched
        with a daily granularity between `start` and `end`. If `for_internal_use` is True,
        returns the probed timestamp as soon as a probe yields any data.

        Returns None if no data could be found."""
        day_ms = 86400000

        if start is not None:
            start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
            fetched_data = fetch_func(start_ts, 1)
            if for_internal_use and len(fetched_data) > 0:
                return pd.Timestamp(start_ts, unit="ms", tz="utc")
        else:
            fetched_data = []
        if len(fetched_data) == 0 and start != 0:
            fetched_data = fetch_func(0, 1)
            if for_internal_use and len(fetched_data) > 0:
                return pd.Timestamp(0, unit="ms", tz="utc")
        if len(fetched_data) == 0:
            # Align both boundaries to whole days
            if start is not None:
                start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
            else:
                start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(0, naive_tz=tz, tz="utc"))
            start_ts = start_ts - start_ts % day_ms
            if end is not None:
                end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
            else:
                end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime("now", naive_tz=tz, tz="utc"))
            end_ts = end_ts - end_ts % day_ms + day_ms
            start_time = start_ts
            end_time = end_ts
            while True:
                mid_time = (start_time + end_time) // 2
                mid_time = mid_time - mid_time % day_ms
                if mid_time == start_time:
                    break
                _fetched_data = fetch_func(mid_time, 1)
                if len(_fetched_data) == 0:
                    # No data from here on: continue in the later half
                    start_time = mid_time
                else:
                    # Data found: remember it and continue in the earlier half
                    end_time = mid_time
                    fetched_data = _fetched_data
        if len(fetched_data) > 0:
            return pd.Timestamp(fetched_data[0][0], unit="ms", tz="utc")
        return None

    @classmethod
    def find_earliest_date(cls, symbol: str, for_internal_use: bool = False, **kwargs) -> tp.Optional[pd.Timestamp]:
        """Find the earliest date using binary search.

        See `CCXTData.fetch_symbol` for arguments."""
        return cls._find_earliest_date(
            **cls.fetch_symbol(symbol, return_fetch_method=True, **kwargs),
            for_internal_use=for_internal_use,
        )

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        exchange: tp.Union[None, str, CCXTExchangeT] = None,
        exchange_config: tp.Optional[tp.KwargsLike] = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        find_earliest_date: tp.Optional[bool] = None,
        limit: tp.Optional[int] = None,
        delay: tp.Optional[float] = None,
        retries: tp.Optional[int] = None,
        fetch_params: tp.Optional[tp.KwargsLike] = None,
        show_progress: tp.Optional[bool] = None,
        pbar_kwargs: tp.KwargsLike = None,
        silence_warnings: tp.Optional[bool] = None,
        return_fetch_method: bool = False,
    ) -> tp.Union[dict, tp.SymbolData]:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from CCXT.

        Args:
            symbol (str): Symbol.

                Symbol can be in the `EXCHANGE:SYMBOL` format if no exchange is provided
                or configured. Only the first colon is used as the separator, such that
                the symbol part may itself contain colons (such as "BTC/USDT:USDT").
            exchange (str or object): Exchange identifier or an exchange object.

                See `CCXTData.resolve_exchange`.
            exchange_config (dict): Exchange config.

                See `CCXTData.resolve_exchange`.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            find_earliest_date (bool): Whether to find the earliest date using `CCXTData.find_earliest_date`.
            limit (int): The maximum number of returned items.
            delay (float): Time to sleep after each request (in seconds).

                !!! note
                    Use only if `enableRateLimit` is not set.
            retries (int): The number of attempts to fetch data before a network error is re-raised.

                At least one attempt is always made.
            fetch_params (dict): Exchange-specific keyword arguments passed to `fetch_ohlcv`.
            show_progress (bool): Whether to show the progress bar.
            pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
            silence_warnings (bool): Whether to silence all warnings.
            return_fetch_method (bool): Whether to return the fetching function along with
                `start`, `end`, and `tz` instead of data.

                Required by `CCXTData.find_earliest_date`.

        For defaults, see `custom.ccxt` in `vectorbtpro._settings.data`.
        Global settings can be provided per exchange id using the `exchanges` dictionary.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("ccxt")
        import ccxt

        exchange = cls.resolve_custom_setting(exchange, "exchange")
        if exchange is None and ":" in symbol:
            # Split on the first colon only: contract symbols contain a colon themselves
            exchange, symbol = symbol.split(":", 1)
        if exchange_config is None:
            exchange_config = {}
        exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
        exchange_name = type(exchange).__name__
        start = cls.resolve_exchange_setting(start, "start", exchange_name=exchange_name)
        end = cls.resolve_exchange_setting(end, "end", exchange_name=exchange_name)
        timeframe = cls.resolve_exchange_setting(timeframe, "timeframe", exchange_name=exchange_name)
        tz = cls.resolve_exchange_setting(tz, "tz", exchange_name=exchange_name)
        find_earliest_date = cls.resolve_exchange_setting(
            find_earliest_date, "find_earliest_date", exchange_name=exchange_name
        )
        limit = cls.resolve_exchange_setting(limit, "limit", exchange_name=exchange_name)
        delay = cls.resolve_exchange_setting(delay, "delay", exchange_name=exchange_name)
        retries = cls.resolve_exchange_setting(retries, "retries", exchange_name=exchange_name)
        fetch_params = cls.resolve_exchange_setting(
            fetch_params, "fetch_params", merge=True, exchange_name=exchange_name
        )
        show_progress = cls.resolve_exchange_setting(show_progress, "show_progress", exchange_name=exchange_name)
        pbar_kwargs = cls.resolve_exchange_setting(pbar_kwargs, "pbar_kwargs", merge=True, exchange_name=exchange_name)
        if "bar_id" not in pbar_kwargs:
            pbar_kwargs["bar_id"] = "ccxt"
        silence_warnings = cls.resolve_exchange_setting(
            silence_warnings, "silence_warnings", exchange_name=exchange_name
        )
        if not exchange.has["fetchOHLCV"]:
            raise ValueError(f"Exchange {exchange} does not support OHLCV")
        if exchange.has["fetchOHLCV"] == "emulated":
            if not silence_warnings:
                warn("Using emulated OHLCV candles")

        # Translate the timeframe into the notation checked against the exchange's timeframes
        freq = timeframe
        split = dt.split_freq_str(timeframe)
        if split is not None:
            multiplier, unit = split
            if unit == "D":
                unit = "d"
            elif unit == "W":
                unit = "w"
            elif unit == "Y":
                unit = "y"
            timeframe = str(multiplier) + unit
        if timeframe not in exchange.timeframes:
            raise ValueError(f"Exchange {exchange} does not support {timeframe} timeframe")

        # Always make at least one attempt, otherwise the exchange would never be called
        # and the wrapped method would silently return None
        num_attempts = 1 if retries is None else max(1, retries)

        def _retry(method):
            @wraps(method)
            def retry_method(*args, **kwargs):
                for i in range(num_attempts):
                    try:
                        return method(*args, **kwargs)
                    except ccxt.NetworkError:
                        if i == num_attempts - 1:
                            raise
                        if not silence_warnings:
                            warn(traceback.format_exc())
                        if delay is not None:
                            time.sleep(delay)

            return retry_method

        @_retry
        def _fetch(_since, _limit):
            return exchange.fetch_ohlcv(
                symbol,
                timeframe=timeframe,
                since=_since,
                limit=_limit,
                params=fetch_params,
            )

        if return_fetch_method:
            return dict(fetch_func=_fetch, start=start, end=end, tz=tz)

        # Establish the timestamps
        if find_earliest_date and start is not None:
            start = cls._find_earliest_date(_fetch, start=start, end=end, tz=tz, for_internal_use=True)
        if start is not None:
            start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
        else:
            start_ts = None
        if end is not None:
            end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
        else:
            end_ts = None
        prev_end_ts = None

        def _ts_to_str(ts: tp.Optional[int]) -> str:
            if ts is None:
                return "?"
            return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)

        def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
            # Keep candles opened within [start_ts, end_ts) and after the previous batch
            if start_ts is not None:
                if d[0] < start_ts:
                    return False
            if _prev_end_ts is not None:
                if d[0] <= _prev_end_ts:
                    return False
            if end_ts is not None:
                if d[0] >= end_ts:
                    return False
            return True

        # Iteratively collect the data
        data = []
        try:
            with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
                pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
                while True:
                    # Fetch the klines for the next timeframe
                    next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit)
                    next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))

                    # Update the timestamps and the progress bar
                    if not len(next_data):
                        break
                    data += next_data
                    if start_ts is None:
                        start_ts = next_data[0][0]
                    pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0])))
                    pbar.update()
                    prev_end_ts = next_data[-1][0]
                    if end_ts is not None and prev_end_ts >= end_ts:
                        break
                    if delay is not None:
                        time.sleep(delay)  # be kind to api
        except Exception:
            # Best effort: return whatever has been collected so far
            if not silence_warnings:
                warn(traceback.format_exc())
                warn(
                    f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
                    "Use update() method to fetch missing data."
                )

        # Convert data to a DataFrame
        df = pd.DataFrame(data, columns=["Open time", "Open", "High", "Low", "Close", "Volume"])
        df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
        del df["Open time"]
        for column in ("Open", "High", "Low", "Close", "Volume"):
            if column in df.columns:
                df[column] = df[column].astype(float)

        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Fetch the symbol again starting at its last index, reusing the original fetch arguments."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
@classmethod
def fetch_key(
    cls,
    key: tp.Key,
    path: tp.Any = None,
    start: tp.Optional[tp.DatetimeLike] = None,
    end: tp.Optional[tp.DatetimeLike] = None,
    tz: tp.TimezoneLike = None,
    start_row: tp.Optional[int] = None,
    end_row: tp.Optional[int] = None,
    header: tp.Optional[tp.MaybeSequence[int]] = None,
    index_col: tp.Optional[int] = None,
    parse_dates: tp.Optional[bool] = None,
    chunk_func: tp.Optional[tp.Callable] = None,
    squeeze: tp.Optional[bool] = None,
    **read_kwargs,
) -> tp.KeyData:
    """Fetch the CSV file of a feature or symbol.

    Args:
        key (hashable): Feature or symbol.
        path (str): Path.

            If `path` is None, uses `key` as the path to the CSV file.
        start (any): Start datetime.

            Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
        end (any): End datetime.

            Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
        tz (any): Target timezone.

            See `vectorbtpro.utils.datetime_.to_timezone`.
        start_row (int): Start row (inclusive).

            Must exclude header rows.
        end_row (int): End row (exclusive).

            Must exclude header rows.
        header (int or sequence of int): See `pd.read_csv`.

            If the resolved value is None, the file is treated as having no header rows.
        index_col (int): See `pd.read_csv`.

            If False, will pass None.
        parse_dates (bool): See `pd.read_csv`.
        chunk_func (callable): Function to select and concatenate chunks from `TextFileReader`.

            Gets called only if `iterator` or `chunksize` are set.
        squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
        **read_kwargs: Other keyword arguments passed to `pd.read_csv`.

    `skiprows` and `nrows` will be automatically calculated based on `start_row` and `end_row`.

    If `sep` is not provided, it is inferred from the path: the last `.csv`/`.tsv` suffix
    (case-insensitive, compression suffixes such as `.gz` are skipped) selects a comma or a tab.
    Any other path falls back to a comma.

    When either `start` or `end` is provided, will fetch the entire data first and filter it thereafter.

    Returns None if filtering by `start` and `end` leaves no rows. Otherwise, returns the data
    and a dict with `last_row` (position of the last fetched row, excluding header rows) and `tz`.

    See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for other arguments.

    For defaults, see `custom.csv` in `vectorbtpro._settings.data`."""
    from pandas.io.parsers import TextFileReader
    from pandas.api.types import is_object_dtype

    start = cls.resolve_custom_setting(start, "start")
    end = cls.resolve_custom_setting(end, "end")
    tz = cls.resolve_custom_setting(tz, "tz")
    start_row = cls.resolve_custom_setting(start_row, "start_row")
    if start_row is None:
        start_row = 0
    end_row = cls.resolve_custom_setting(end_row, "end_row")
    header = cls.resolve_custom_setting(header, "header")
    index_col = cls.resolve_custom_setting(index_col, "index_col")
    if index_col is False:
        index_col = None
    parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates")
    chunk_func = cls.resolve_custom_setting(chunk_func, "chunk_func")
    squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
    read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)

    if path is None:
        path = key

    # Translate data-relative rows into file-relative rows by accounting for header rows
    if header is None:
        # No header rows in the file: data rows start at the very first line
        header_rows = 0
    else:
        if isinstance(header, int):
            header = [header]
        header_rows = header[-1] + 1
    start_row += header_rows
    if end_row is not None:
        end_row += header_rows
    skiprows = range(header_rows, start_row)
    nrows = None if end_row is None else end_row - start_row

    # Infer the separator from the last CSV/TSV suffix so that compressed files
    # such as "data.tsv.gz" are handled the same way `is_csv_file` recognizes them
    sep = read_kwargs.pop("sep", None)
    if sep is None and isinstance(path, (str, Path)):
        for suffix in reversed(Path(path).suffixes):
            suffix = suffix.lower()
            if suffix == ".csv":
                sep = ","
                break
            if suffix == ".tsv":
                sep = "\t"
                break
    if sep is None:
        sep = ","

    obj = pd.read_csv(
        path,
        sep=sep,
        header=header,
        index_col=index_col,
        parse_dates=parse_dates,
        skiprows=skiprows,
        nrows=nrows,
        **read_kwargs,
    )
    if isinstance(obj, TextFileReader):
        # Returned when `iterator` or `chunksize` is set in `read_kwargs`
        if chunk_func is None:
            obj = pd.concat(list(obj), axis=0)
        else:
            obj = chunk_func(obj)
    if isinstance(obj, pd.DataFrame) and squeeze:
        obj = obj.squeeze("columns")
    if isinstance(obj, pd.Series) and obj.name == "0":
        obj.name = None
    if index_col is not None and parse_dates and is_object_dtype(obj.index.dtype):
        # The index was left unparsed (object dtype): parse it explicitly as UTC
        obj.index = pd.to_datetime(obj.index, utc=True)
        if tz is not None:
            obj.index = obj.index.tz_convert(tz)
    if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
        tz = obj.index.tz

    if start is not None or end is not None:
        if not isinstance(obj.index, pd.DatetimeIndex):
            raise TypeError("Cannot filter index that is not DatetimeIndex")
        # Bring the bounds into the same (aware or naive) representation as the index
        if obj.index.tz is not None:
            if start is not None:
                start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=obj.index.tz)
            if end is not None:
                end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=obj.index.tz)
        else:
            if start is not None:
                start = dt.to_naive_timestamp(start, tz=tz)
            if end is not None:
                end = dt.to_naive_timestamp(end, tz=tz)
        mask = True
        if start is not None:
            mask &= obj.index >= start
        if end is not None:
            mask &= obj.index < end
        mask_indices = np.flatnonzero(mask)
        if len(mask_indices) == 0:
            return None
        # Keep the contiguous range between the first and the last matching row
        obj = obj.iloc[mask_indices[0] : mask_indices[-1] + 1]
        start_row += mask_indices[0]
    return obj, dict(last_row=start_row - header_rows + len(obj.index) - 1, tz=tz)
Uses `CSVData.update_key` with `key_is_feature=True`.""" return self.update_key(feature, key_is_feature=True, **kwargs) def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData: """Update data for a symbol. Uses `CSVData.update_key` with `key_is_feature=False`.""" return self.update_key(symbol, key_is_feature=False, **kwargs) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Module with `CustomData`.""" import fnmatch import re from vectorbtpro import _typing as tp from vectorbtpro.data.base import Data __all__ = [ "CustomData", ] __pdoc__ = {} class CustomData(Data): """Data class for fetching custom data.""" _settings_path: tp.SettingsPath = dict(custom=None) @classmethod def get_custom_settings(cls, *args, **kwargs) -> dict: """`CustomData.get_settings` with `path_id="custom"`.""" return cls.get_settings(*args, path_id="custom", **kwargs) @classmethod def has_custom_settings(cls, *args, **kwargs) -> bool: """`CustomData.has_settings` with `path_id="custom"`.""" return cls.has_settings(*args, path_id="custom", **kwargs) @classmethod def get_custom_setting(cls, *args, **kwargs) -> tp.Any: """`CustomData.get_setting` with `path_id="custom"`.""" return cls.get_setting(*args, path_id="custom", **kwargs) @classmethod def has_custom_setting(cls, *args, **kwargs) -> bool: """`CustomData.has_setting` with `path_id="custom"`.""" return cls.has_setting(*args, path_id="custom", **kwargs) @classmethod def resolve_custom_setting(cls, *args, **kwargs) -> tp.Any: """`CustomData.resolve_setting` 
with `path_id="custom"`.""" return cls.resolve_setting(*args, path_id="custom", **kwargs) @classmethod def set_custom_settings(cls, *args, **kwargs) -> None: """`CustomData.set_settings` with `path_id="custom"`.""" cls.set_settings(*args, path_id="custom", **kwargs) @staticmethod def key_match(key: str, pattern: str, use_regex: bool = False): """Return whether key matches pattern. If `use_regex` is True, checks against a regular expression. Otherwise, checks against a glob-style pattern.""" if use_regex: return re.match(pattern, key) return re.match(fnmatch.translate(pattern), key) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Module with `DBData`.""" from vectorbtpro import _typing as tp from vectorbtpro.data.custom.local import LocalData __all__ = [ "DBData", ] __pdoc__ = {} class DBData(LocalData): """Data class for fetching database data.""" _settings_path: tp.SettingsPath = dict(custom="data.custom.db") # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# =================================================================================== """Module with `DuckDBData`.""" from pathlib import Path import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.data.base import key_dict from vectorbtpro.data.custom.db import DBData from vectorbtpro.data.custom.file import FileData from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.config import merge_dicts try: if not tp.TYPE_CHECKING: raise ImportError from duckdb import DuckDBPyConnection as DuckDBPyConnectionT, DuckDBPyRelation as DuckDBPyRelationT except ImportError: DuckDBPyConnectionT = "DuckDBPyConnection" DuckDBPyRelationT = "DuckDBPyRelation" __all__ = [ "DuckDBData", ] __pdoc__ = {} DuckDBDataT = tp.TypeVar("DuckDBDataT", bound="DuckDBData") class DuckDBData(DBData): """Data class for fetching data using DuckDB. See `DuckDBData.pull` and `DuckDBData.fetch_key` for arguments. Usage: * Set up the connection settings globally (optional): ```pycon >>> from vectorbtpro import * >>> vbt.DuckDBData.set_custom_settings(connection="database.duckdb") ``` * Pull tables: ```pycon >>> data = vbt.DuckDBData.pull(["TABLE1", "TABLE2"]) ``` * Rename tables: ```pycon >>> data = vbt.DuckDBData.pull( ... ["SYMBOL1", "SYMBOL2"], ... table=vbt.key_dict({ ... "SYMBOL1": "TABLE1", ... "SYMBOL2": "TABLE2" ... }) ... ) ``` * Pull queries: ```pycon >>> data = vbt.DuckDBData.pull( ... ["SYMBOL1", "SYMBOL2"], ... query=vbt.key_dict({ ... "SYMBOL1": "SELECT * FROM TABLE1", ... "SYMBOL2": "SELECT * FROM TABLE2" ... }) ... ) ``` * Pull Parquet files: ```pycon >>> data = vbt.DuckDBData.pull( ... ["SYMBOL1", "SYMBOL2"], ... read_path=vbt.key_dict({ ... "SYMBOL1": "s1.parquet", ... "SYMBOL2": "s2.parquet" ... }) ... 
@classmethod
def list_catalogs(
    cls,
    pattern: tp.Optional[str] = None,
    use_regex: bool = False,
    sort: bool = True,
    incl_system: bool = False,
    connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
    connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
    """List all catalogs.

    Catalogs "system" and "temp" are skipped if `incl_system` is False.

    Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog
    against `pattern`.

    Duplicates are removed. The result is sorted if `sort` is True, otherwise
    the order of first appearance is kept.

    A connection opened by this method is closed even if the query fails."""
    if connection_config is None:
        connection_config = {}
    connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
    connection = connection_meta["connection"]
    should_close = connection_meta["should_close"]
    try:
        # The result is materialized as a DataFrame, so the connection isn't needed afterwards
        schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df()
    finally:
        if should_close:
            connection.close()

    skipped_catalogs = () if incl_system else ("system", "temp")
    catalogs = []
    for catalog in schemata_df["catalog_name"].tolist():
        if pattern is not None and not cls.key_match(catalog, pattern, use_regex=use_regex):
            continue
        if catalog in skipped_catalogs:
            continue
        catalogs.append(catalog)

    # A catalog appears once per schema: deduplicate while keeping the order
    unique_catalogs = dict.fromkeys(catalogs)
    if sort:
        return sorted(unique_catalogs)
    return list(unique_catalogs)
@classmethod
def list_tables(
    cls,
    *,
    catalog_pattern: tp.Optional[str] = None,
    schema_pattern: tp.Optional[str] = None,
    table_pattern: tp.Optional[str] = None,
    use_regex: bool = False,
    sort: bool = True,
    catalog: tp.Optional[str] = None,
    schema: tp.Optional[str] = None,
    incl_system: bool = False,
    incl_temporary: bool = False,
    incl_views: bool = True,
    connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
    connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
    """List all tables and views.

    If `schema` is None, searches for all schema names in the database and prefixes each table
    with the respective catalog and schema name (unless there's only one schema which is the current
    schema or `schema` is `current_schema`). If `schema` is provided, returns the tables corresponding
    to this schema without a prefix.

    Temporary tables are included only if `incl_temporary` is True, views only if `incl_views` is True.

    Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog against
    `catalog_pattern`, each schema against `schema_pattern`, and each table against `table_pattern`.

    Duplicates are removed. The result is sorted if `sort` is True.

    A connection opened by this method is closed even if a query fails."""
    if connection_config is None:
        connection_config = {}
    connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
    connection = connection_meta["connection"]
    should_close = connection_meta["should_close"]
    try:
        # NOTE: Nested calls get only the resolved connection. Passing `connection_config`
        # along with an already initialized connection makes `resolve_connection` raise.
        if catalog is None:
            catalogs = cls.list_catalogs(
                pattern=catalog_pattern,
                use_regex=use_regex,
                sort=sort,
                incl_system=incl_system,
                connection=connection,
            )
            prefix_catalog = not (catalog_pattern is None and len(catalogs) == 1)
        else:
            catalogs = [catalog]
            prefix_catalog = False

        current_schema = cls.get_current_schema(connection=connection)
        catalogs_schemas = []
        if schema is None:
            for cat in catalogs:
                cat_schemas = cls.list_schemas(
                    schema_pattern=schema_pattern,
                    use_regex=use_regex,
                    sort=sort,
                    catalog=cat,
                    incl_system=incl_system,
                    connection=connection,
                )
                for sch in cat_schemas:
                    catalogs_schemas.append((cat, sch))
            # A single schema that is also the current one needs no prefix
            prefix_schema = not (len(catalogs_schemas) == 1 and catalogs_schemas[0][1] == current_schema)
        else:
            if schema == "current_schema":
                schema = current_schema
            for cat in catalogs:
                catalogs_schemas.append((cat, schema))
            prefix_schema = prefix_catalog

        tables_df = connection.sql("SELECT * FROM information_schema.tables").df()
    finally:
        if should_close:
            connection.close()

    # Table types are collected in this order: base tables, temporary tables, views
    table_types = ["BASE TABLE"]
    if incl_temporary:
        table_types.append("LOCAL TEMPORARY")
    if incl_views:
        table_types.append("VIEW")

    tables = []
    for cat, sch in catalogs_schemas:
        in_schema = (tables_df["table_catalog"] == cat) & (tables_df["table_schema"] == sch)
        all_tables = []
        for table_type in table_types:
            type_mask = in_schema & (tables_df["table_type"] == table_type)
            all_tables.extend(tables_df[type_mask]["table_name"].tolist())
        for table in all_tables:
            if table_pattern is not None and not cls.key_match(table, table_pattern, use_regex=use_regex):
                continue
            if not prefix_catalog and prefix_schema:
                table = sch + ":" + table
            elif prefix_catalog or prefix_schema:
                table = cat + ":" + sch + ":" + table
            tables.append(table)

    unique_tables = dict.fromkeys(tables)
    if sort:
        return sorted(unique_tables)
    return list(unique_tables)
@classmethod
def pull(
    cls: tp.Type[DuckDBDataT],
    keys: tp.Union[tp.MaybeKeys] = None,
    *,
    keys_are_features: tp.Optional[bool] = None,
    features: tp.Union[tp.MaybeFeatures] = None,
    symbols: tp.Union[tp.MaybeSymbols] = None,
    catalog: tp.Optional[str] = None,
    schema: tp.Optional[str] = None,
    list_tables_kwargs: tp.KwargsLike = None,
    read_path: tp.Optional[tp.PathLike] = None,
    read_format: tp.Optional[str] = None,
    connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
    connection_config: tp.KwargsLike = None,
    share_connection: tp.Optional[bool] = None,
    **kwargs,
) -> DuckDBDataT:
    """Override `vectorbtpro.data.base.Data.pull` to resolve and share the connection among the keys
    and use the table names available in the database in case no keys were provided.

    If `share_connection` is None, the connection is shared unless `connection` or
    `connection_config` is defined per key. A shared connection is resolved once with
    `DuckDBData.resolve_connection` and, if it was opened here, closed once pulling is
    done or has failed.

    Keys of the type `pathlib.Path` are converted into keys with `FileData.path_to_key`,
    while the original paths are passed downstream as `read_path` per key."""
    if share_connection is None:
        share_connection = not cls.has_key_dict(connection) and not cls.has_key_dict(connection_config)
    if share_connection:
        if connection_config is None:
            connection_config = {}
        connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
        connection = connection_meta["connection"]
        should_close = connection_meta["should_close"]
        # The config has been consumed by the connection above. Forwarding it together with
        # the initialized connection would make `resolve_connection` raise downstream.
        connection_config = {}
    else:
        should_close = False

    try:
        keys_meta = cls.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
            catalog=catalog,
            schema=schema,
            list_tables_kwargs=list_tables_kwargs,
            read_path=read_path,
            read_format=read_format,
            connection=connection,
            connection_config=connection_config,
        )
        keys = keys_meta["keys"]

        # Replace path keys by regular keys and remember their paths per key
        if isinstance(read_path, key_dict):
            new_read_path = read_path.copy()
        else:
            new_read_path = key_dict()
        if isinstance(keys, dict):
            new_keys = {}
            for k, v in keys.items():
                if isinstance(k, Path):
                    new_k = FileData.path_to_key(k)
                    new_read_path[new_k] = k
                    k = new_k
                new_keys[k] = v
            keys = new_keys
        elif cls.has_multiple_keys(keys):
            new_keys = []
            for k in keys:
                if isinstance(k, Path):
                    new_k = FileData.path_to_key(k)
                    new_read_path[new_k] = k
                    k = new_k
                new_keys.append(k)
            keys = new_keys
        elif isinstance(keys, Path):
            new_key = FileData.path_to_key(keys)
            new_read_path[new_key] = keys
            keys = new_key
        if len(new_read_path) > 0:
            read_path = new_read_path

        keys_are_features = keys_meta["keys_are_features"]
        return super(DBData, cls).pull(
            keys,
            keys_are_features=keys_are_features,
            catalog=catalog,
            schema=schema,
            read_path=read_path,
            read_format=read_format,
            connection=connection,
            connection_config=connection_config,
            **kwargs,
        )
    finally:
        if should_close:
            connection.close()
f"'{option}'" if isinstance(option, (tuple, list)): return "(" + ", ".join(map(str, option)) + ")" if isinstance(option, dict): return "{" + ", ".join(map(lambda y: f"{y[0]}: {cls.format_write_option(y[1])}", option.items())) + "}" return f"{option}" @classmethod def format_write_options(cls, options: tp.Union[str, dict]) -> str: """Format write options.""" if isinstance(options, str): return options new_options = [] for k, v in options.items(): new_options.append(f"{k.upper()} {cls.format_write_option(v)}") return ", ".join(new_options) @classmethod def format_read_option(cls, option: tp.Any) -> str: """Format a read option.""" if isinstance(option, str): return f"'{option}'" if isinstance(option, (tuple, list)): return "[" + ", ".join(map(cls.format_read_option, option)) + "]" if isinstance(option, dict): return "{" + ", ".join(map(lambda y: f"'{y[0]}': {cls.format_read_option(y[1])}", option.items())) + "}" return f"{option}" @classmethod def format_read_options(cls, options: tp.Union[str, dict]) -> str: """Format read options.""" if isinstance(options, str): return options new_options = [] for k, v in options.items(): new_options.append(f"{k.lower()}={cls.format_read_option(v)}") return ", ".join(new_options) @classmethod def fetch_key( cls, key: str, table: tp.Optional[str] = None, schema: tp.Optional[str] = None, catalog: tp.Optional[str] = None, read_path: tp.Optional[tp.PathLike] = None, read_format: tp.Optional[str] = None, read_options: tp.Union[None, str, dict] = None, query: tp.Union[None, str, DuckDBPyRelationT] = None, connection: tp.Union[None, str, DuckDBPyConnectionT] = None, connection_config: tp.KwargsLike = None, start: tp.Optional[tp.Any] = None, end: tp.Optional[tp.Any] = None, align_dates: tp.Optional[bool] = None, parse_dates: tp.Union[None, bool, tp.Sequence[str]] = None, to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None, tz: tp.TimezoneLike = None, index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None, squeeze: 
tp.Optional[bool] = None, df_kwargs: tp.KwargsLike = None, **sql_kwargs, ) -> tp.KeyData: """Fetch a feature or symbol from a DuckDB database. Can use a table name (which defaults to the key) or a custom query. Args: key (str): Feature or symbol. If `table` and `query` are both None, becomes the table name. Key can be in the `SCHEMA:TABLE` format, in this case `schema` argument will be ignored. table (str): Table name. Cannot be used together with `file` or `query`. schema (str): Schema name. Cannot be used together with `file` or `query`. catalog (str): Catalog name. Cannot be used together with ``file` or query`. read_path (path_like): Path to a file to read. Cannot be used together with `table`, `schema`, `catalog`, or `query`. read_format (str): Format of the file to read. Allowed values are "csv", "parquet", and "json". Requires `read_path` to be set. read_options (str or dict): Options used to read the file. Requires `read_path` and `read_format` to be set. Uses `DuckDBData.format_read_options` to transform a dictionary to a string. query (str or DuckDBPyRelation): Custom query. Cannot be used together with `catalog`, `schema`, and `table`. connection (str or object): See `DuckDBData.resolve_connection`. connection_config (dict): See `DuckDBData.resolve_connection`. start (any): Start datetime (if datetime index) or any other start value. Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True and the index is a datetime index. Otherwise, you must ensure the correct type is provided. Cannot be used together with `query`. Include the condition into the query. end (any): End datetime (if datetime index) or any other end value. Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True and the index is a datetime index. Otherwise, you must ensure the correct type is provided. Cannot be used together with `query`. Include the condition into the query. 
align_dates (bool): Whether to align `start` and `end` to the timezone of the index. Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index. parse_dates (bool or sequence of str): See `DuckDBData.prepare_dt`. to_utc (bool, str, or sequence of str): See `DuckDBData.prepare_dt`. tz (any): Timezone. See `vectorbtpro.utils.datetime_.to_timezone`. index_col (int, str, or list): One or more columns that should become the index. squeeze (int): Whether to squeeze a DataFrame with one column into a Series. df_kwargs (dict): Keyword arguments passed to `relation.df` to convert a relation to a DataFrame. **sql_kwargs: Other keyword arguments passed to `connection.execute` to run a SQL query. For defaults, see `custom.duckdb` in `vectorbtpro._settings.data`.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("duckdb") from duckdb import DuckDBPyRelation if connection_config is None: connection_config = {} connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config) connection = connection_meta["connection"] should_close = connection_meta["should_close"] if catalog is not None and query is not None: raise ValueError("Cannot use catalog and query together") if schema is not None and query is not None: raise ValueError("Cannot use schema and query together") if table is not None and query is not None: raise ValueError("Cannot use table and query together") if read_path is not None and query is not None: raise ValueError("Cannot use read_path and query together") if read_path is not None and (catalog is not None or schema is not None or table is not None): raise ValueError("Cannot use read_path and catalog/schema/table together") if table is None and read_path is None and read_format is None and query is None: if ":" in key: key_parts = key.split(":") if len(key_parts) == 2: schema, table = key_parts else: catalog, schema, table = key_parts else: table = key if read_format is not None: read_format 
= read_format.lower() checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format") if read_path is None: read_path = (Path(".") / key).with_suffix("." + read_format) else: if read_path is not None: if isinstance(read_path, str): read_path = Path(read_path) if read_path.suffix[1:] in ["csv", "parquet", "json"]: read_format = read_path.suffix[1:] if read_path is not None: if isinstance(read_path, Path): read_path = str(read_path) read_path = cls.format_read_option(read_path) if read_options is not None: if read_format is None: raise ValueError("Must provide read_format for read_options") read_options = cls.format_read_options(read_options) catalog = cls.resolve_custom_setting(catalog, "catalog") schema = cls.resolve_custom_setting(schema, "schema") start = cls.resolve_custom_setting(start, "start") end = cls.resolve_custom_setting(end, "end") align_dates = cls.resolve_custom_setting(align_dates, "align_dates") parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates") to_utc = cls.resolve_custom_setting(to_utc, "to_utc") tz = cls.resolve_custom_setting(tz, "tz") index_col = cls.resolve_custom_setting(index_col, "index_col") squeeze = cls.resolve_custom_setting(squeeze, "squeeze") df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True) sql_kwargs = cls.resolve_custom_setting(sql_kwargs, "sql_kwargs", merge=True) if query is None: if read_path is not None: if read_options is not None: query = f"SELECT * FROM read_{read_format}({read_path}, {read_options})" elif read_format is not None: query = f"SELECT * FROM read_{read_format}({read_path})" else: query = f"SELECT * FROM {read_path}" else: if catalog is not None: if schema is None: schema = cls.get_current_schema( connection=connection, connection_config=connection_config, ) query = f'SELECT * FROM "{catalog}"."{schema}"."{table}"' elif schema is not None: query = f'SELECT * FROM "{schema}"."{table}"' else: query = f'SELECT * FROM "{table}"' if start is not None or end 
is not None: if index_col is None: raise ValueError("Must provide index column for filtering by start and end") if not checks.is_int(index_col) and not isinstance(index_col, str): raise ValueError("Index column must be integer or string for filtering by start and end") if checks.is_int(index_col) or align_dates: metadata_df = connection.sql("DESCRIBE " + query + " LIMIT 1").df() else: metadata_df = None if checks.is_int(index_col): index_name = metadata_df["column_name"].tolist()[0] else: index_name = index_col if parse_dates: index_column_type = metadata_df[metadata_df["column_name"] == index_name]["column_type"].item() if index_column_type in ( "TIMESTAMP_NS", "TIMESTAMP_MS", "TIMESTAMP_S", "TIMESTAMP", "DATETIME", ): if start is not None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and index_name in to_utc) ): start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") start = dt.to_naive_datetime(start) else: start = dt.to_naive_datetime(start, tz=tz) if end is not None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and index_name in to_utc) ): end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") end = dt.to_naive_datetime(end) else: end = dt.to_naive_datetime(end, tz=tz) elif index_column_type in ("TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"): if start is not None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and index_name in to_utc) ): start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") else: start = dt.to_tzaware_datetime(start, naive_tz=tz) if end is not None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and index_name in to_utc) ): end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") else: end = dt.to_tzaware_datetime(end, naive_tz=tz) if start is not None and end is not None: query += 
@classmethod
def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData:
    """Fetch the table of a feature.

    Forwards `feature` and all keyword arguments to `DuckDBData.fetch_key`."""
    return cls.fetch_key(feature, **kwargs)

@classmethod
def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData:
    """Fetch the table for a symbol.

    Forwards `symbol` and all keyword arguments to `DuckDBData.fetch_key`."""
    return cls.fetch_key(symbol, **kwargs)

def update_key(self, key: str, from_last_index: tp.Optional[bool] = None, **kwargs) -> tp.KeyData:
    """Update data of a feature or symbol.

    Args:
        key (str): Feature or symbol.
        from_last_index (bool): Whether to set `start` to the last stored index of `key`.

            If None, becomes False when the effective keyword arguments carry
            a non-None `query`, and True otherwise.
        **kwargs: Keyword arguments that override the stored fetch arguments.

    Dispatches to `fetch_feature` if the instance is feature-oriented, otherwise to `fetch_symbol`."""
    stored_kwargs = self.select_fetch_kwargs(key)
    # Preview of the final arguments, used only to look for a custom query
    preview_kwargs = merge_dicts(stored_kwargs, kwargs)
    if from_last_index is None:
        from_last_index = preview_kwargs.get("query", None) is None
    if from_last_index:
        stored_kwargs["start"] = self.select_last_index(key)
    # User-supplied arguments take precedence over the stored ones (including `start`)
    final_kwargs = merge_dicts(stored_kwargs, kwargs)
    fetcher = self.fetch_feature if self.feature_oriented else self.fetch_symbol
    return fetcher(key, **final_kwargs)

def update_feature(self, feature: str, **kwargs) -> tp.FeatureData:
    """Update data of a feature.

    Delegates to `DuckDBData.update_key`."""
    return self.update_key(feature, **kwargs)

def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
    """Update data for a symbol.

    Delegates to `DuckDBData.update_key`."""
    return self.update_key(symbol, **kwargs)
class FeatherData(FileData):
    """Data class for fetching Feather data using PyArrow."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.feather")

    @classmethod
    def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
        """List all paths under `path`.

        An existing directory is turned into the glob pattern `<path>/*.feather`
        before being handed over to `match_path`."""
        target = path if isinstance(path, Path) else Path(path)
        if target.exists() and target.is_dir():
            target = target / "*.feather"
        return cls.match_path(target, **match_path_kwargs)

    @classmethod
    def resolve_keys_meta(
        cls,
        keys: tp.Union[None, dict, tp.MaybeKeys] = None,
        keys_are_features: tp.Optional[bool] = None,
        features: tp.Union[None, dict, tp.MaybeFeatures] = None,
        symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
        paths: tp.Any = None,
    ) -> tp.Kwargs:
        """Resolve the keys metadata using `FileData.resolve_keys_meta`.

        If neither keys nor `paths` are provided, keys default to the pattern "*.feather"."""
        meta = FileData.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
        )
        nothing_given = meta["keys"] is None and paths is None
        if nothing_given:
            meta["keys"] = "*.feather"
        return meta

    @classmethod
    def fetch_key(
        cls,
        key: tp.Key,
        path: tp.Any = None,
        tz: tp.TimezoneLike = None,
        index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None,
        squeeze: tp.Optional[bool] = None,
        **read_kwargs,
    ) -> tp.KeyData:
        """Fetch the Feather file of a feature or symbol.

        Args:
            key (hashable): Feature or symbol.
            path (str): Path.

                If `path` is None, uses `key` as the path to the Feather file.
            tz (any): Target timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
                If None and the resulting index is a datetime index, the timezone of the index is used.
            index_col (int, str, or sequence): Position(s) or name(s) of column(s) that should become the index.

                Will only apply if the fetched object has a default index.
                False is treated as None.
            squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
            **read_kwargs: Other keyword arguments passed to `pd.read_feather`.

        See https://pandas.pydata.org/docs/reference/api/pandas.read_feather.html for other arguments.

        For defaults, see `custom.feather` in `vectorbtpro._settings.data`.

        Returns:
            The fetched object and a dict with the key `tz`."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("pyarrow")

        tz = cls.resolve_custom_setting(tz, "tz")
        index_col = cls.resolve_custom_setting(index_col, "index_col")
        if index_col is False:
            index_col = None
        squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
        read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)

        source = key if path is None else path
        obj = pd.read_feather(source, **read_kwargs)

        if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index) and index_col is not None:
            if checks.is_int(index_col):
                # Single position -> column label
                index_keys = obj.columns[index_col]
            elif isinstance(index_col, str):
                index_keys = index_col
            else:
                # Sequence mixing positions and labels
                index_keys = [obj.columns[col] if checks.is_int(col) else col for col in index_col]
            obj = obj.set_index(index_keys)
            # An index column literally called "index" carries no information in its name
            if not isinstance(obj.index, pd.MultiIndex) and obj.index.name == "index":
                obj.index.name = None

        if tz is None and isinstance(obj.index, pd.DatetimeIndex):
            tz = obj.index.tz
        if squeeze and isinstance(obj, pd.DataFrame):
            obj = obj.squeeze("columns")
        if isinstance(obj, pd.Series) and obj.name == "0":
            obj.name = None
        return obj, dict(tz=tz)

    @classmethod
    def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Fetch the Feather file of a feature.

        Delegates to `FeatherData.fetch_key`."""
        return cls.fetch_key(feature, **kwargs)

    @classmethod
    def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Fetch the Feather file of a symbol.

        Delegates to `FeatherData.fetch_key`."""
        return cls.fetch_key(symbol, **kwargs)

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Update data of a feature or symbol.

        Re-fetches `key` with the stored fetch arguments overridden by `kwargs`."""
        stored_kwargs = self.select_fetch_kwargs(key)
        merged_kwargs = merge_dicts(stored_kwargs, kwargs)
        fetcher = self.fetch_feature if key_is_feature else self.fetch_symbol
        return fetcher(key, **merged_kwargs)

    def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Update data of a feature.

        Calls `FeatherData.update_key` with `key_is_feature=True`."""
        return self.update_key(feature, key_is_feature=True, **kwargs)

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Calls `FeatherData.update_key` with `key_is_feature=False`."""
        return self.update_key(symbol, key_is_feature=False, **kwargs)
class FileData(LocalData):
    """Data class for fetching file data."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.file")

    @classmethod
    def is_dir_match(cls, path: tp.PathLike) -> bool:
        """Return whether a directory is a valid match.

        Always False in this class; meant to be overridden by subclasses."""
        return False

    @classmethod
    def is_file_match(cls, path: tp.PathLike) -> bool:
        """Return whether a file is a valid match.

        Always True in this class; meant to be overridden by subclasses."""
        return True

    @classmethod
    def match_path(
        cls,
        path: tp.PathLike,
        match_regex: tp.Optional[str] = None,
        sort_paths: bool = True,
        recursive: bool = True,
        extension: tp.Optional[str] = None,
        **kwargs,
    ) -> tp.List[Path]:
        """Get the list of all paths matching a path.

        If `FileData.is_dir_match` returns True for a directory, it gets returned as-is.
        Otherwise, iterates through all files in that directory and invokes
        `FileData.is_file_match`. If a pattern was provided, these methods aren't invoked.

        Args:
            path (path_like): Existing file, existing directory, or a `glob.glob` pattern.
            match_regex (str): Regular expression that each matched path (as a string)
                must satisfy via `re.match`.
            sort_paths (bool): Whether to sort the matched paths.
            recursive (bool): Passed to `glob.glob` when `path` doesn't exist.
            extension (str): File extension, with or without the leading dot, that files
                inside a directory must have. Only applied when iterating over a directory.
            **kwargs: Ignored here; accepted so that overriding methods can take more arguments.

        Returns:
            List of matched paths."""
        if not isinstance(path, Path):
            path = Path(path)
        # `Path.suffix` always carries exactly one leading dot, so normalize the
        # user-supplied extension once instead of building "..ext" for ".ext"
        suffix = None if extension is None else "." + extension.lstrip(".")

        if path.exists():
            if path.is_dir() and not cls.is_dir_match(path):
                sub_paths = []
                for p in path.iterdir():
                    if p.is_dir() and cls.is_dir_match(p):
                        sub_paths.append(p)
                    if p.is_file() and cls.is_file_match(p):
                        if suffix is None or p.suffix == suffix:
                            sub_paths.append(p)
            else:
                # Existing file, or a directory that is itself a match
                sub_paths = [path]
        else:
            # Not an existing path: treat it as a glob expression
            sub_paths = [Path(p) for p in glob(str(path), recursive=recursive)]
        if match_regex is not None:
            sub_paths = [p for p in sub_paths if re.match(match_regex, str(p))]
        if sort_paths:
            sub_paths = sorted(sub_paths)
        return sub_paths

    @classmethod
    def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
        """List all features or symbols under a path.

        Delegates to `FileData.match_path`."""
        return cls.match_path(path, **match_path_kwargs)

    @classmethod
    def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
        """Convert a path into a feature or symbol.

        Returns the stem of the path (final component without its last suffix)."""
        return Path(path).stem

    @classmethod
    def resolve_keys_meta(
        cls,
        keys: tp.Union[None, dict, tp.MaybeKeys] = None,
        keys_are_features: tp.Optional[bool] = None,
        features: tp.Union[None, dict, tp.MaybeFeatures] = None,
        symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
        paths: tp.Any = None,
    ) -> tp.Kwargs:
        """Resolve the keys metadata using `LocalData.resolve_keys_meta`.

        The argument `paths` is accepted for subclasses but not used here."""
        return LocalData.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
        )

    @classmethod
    def pull(
        cls: tp.Type[FileDataT],
        keys: tp.Union[tp.MaybeKeys] = None,
        *,
        keys_are_features: tp.Optional[bool] = None,
        features: tp.Union[tp.MaybeFeatures] = None,
        symbols: tp.Union[tp.MaybeSymbols] = None,
        paths: tp.Any = None,
        match_paths: tp.Optional[bool] = None,
        match_regex: tp.Optional[str] = None,
        sort_paths: tp.Optional[bool] = None,
        match_path_kwargs: tp.KwargsLike = None,
        path_to_key_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> FileDataT:
        """Override `vectorbtpro.data.base.Data.pull` to take care of paths.

        Use either features, symbols, or `paths` to specify the path to one or multiple files.
        Allowed are paths in a string or `pathlib.Path` format, or string expressions accepted by `glob.glob`.

        Set `match_paths` to False to not parse paths and behave like a regular
        `vectorbtpro.data.base.Data` instance.

        When only keys or only paths are given, keys and paths are kept in sync: each matched
        path is converted into a key with `FileData.path_to_key`. When both are given,
        their numbers must agree.

        Raises:
            ValueError: If neither keys nor paths are set, or if the number of keys differs
                from the number of (matched) paths.
            FileNotFoundError: If no path could be matched.
            TypeError: If `paths` is of an unsupported type.

        For defaults, see `custom.local` in `vectorbtpro._settings.data`.
        """
        keys_meta = cls.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
            paths=paths,
        )
        keys = keys_meta["keys"]
        keys_are_features = keys_meta["keys_are_features"]
        dict_type = keys_meta["dict_type"]

        match_paths = cls.resolve_custom_setting(match_paths, "match_paths")
        match_regex = cls.resolve_custom_setting(match_regex, "match_regex")
        sort_paths = cls.resolve_custom_setting(sort_paths, "sort_paths")

        if match_paths:
            # `sync` means keys are derived from the paths rather than supplied separately
            sync = False
            if paths is None:
                paths = keys
                sync = True
            elif keys is None:
                sync = True
            if paths is None:
                if keys_are_features:
                    raise ValueError("At least features or paths must be set")
                else:
                    raise ValueError("At least symbols or paths must be set")
            if match_path_kwargs is None:
                match_path_kwargs = {}
            if path_to_key_kwargs is None:
                path_to_key_kwargs = {}

            single_key = False
            if isinstance(keys, (str, Path)):
                # Single key
                keys = [keys]
                single_key = True

            single_path = False
            if isinstance(paths, (str, Path)):
                # Single path
                paths = [paths]
                single_path = True
                if sync:
                    single_key = True

            cls.check_dict_type(paths, "paths", dict_type=dict_type)
            if isinstance(paths, key_dict):
                # Dict of path per key
                if sync:
                    keys = list(paths.keys())
                elif len(keys) != len(paths):
                    if keys_are_features:
                        raise ValueError("The number of features must be equal to the number of matched paths")
                    else:
                        raise ValueError("The number of symbols must be equal to the number of matched paths")
            elif checks.is_iterable(paths) or checks.is_sequence(paths):
                # Multiple paths
                matched_paths = [
                    p
                    for sub_path in paths
                    for p in cls.match_path(
                        sub_path,
                        match_regex=match_regex,
                        sort_paths=sort_paths,
                        **match_path_kwargs,
                    )
                ]
                if len(matched_paths) == 0:
                    raise FileNotFoundError(f"No paths could be matched with {paths}")
                if sync:
                    keys = []
                    paths = key_dict()
                    for p in matched_paths:
                        s = cls.path_to_key(p, **path_to_key_kwargs)
                        keys.append(s)
                        paths[s] = p
                elif len(keys) != len(matched_paths):
                    if keys_are_features:
                        raise ValueError("The number of features must be equal to the number of matched paths")
                    else:
                        raise ValueError("The number of symbols must be equal to the number of matched paths")
                else:
                    paths = key_dict({s: matched_paths[i] for i, s in enumerate(keys)})
                if len(matched_paths) == 1 and single_path:
                    # A single input path that matched a single file stays a plain path
                    paths = matched_paths[0]
            else:
                raise TypeError(f"Path '{paths}' is not supported")
            if len(keys) == 1 and single_key:
                keys = keys[0]

        return super(FileData, cls).pull(
            keys,
            keys_are_features=keys_are_features,
            path=paths,
            **kwargs,
        )
class FinPyData(RemoteData):
    """Data class for fetching using findatapy.

    See https://github.com/cuemacro/findatapy for API.

    See `FinPyData.fetch_symbol` for arguments.

    Usage:
        * Pull data (keyword argument format):

        ```pycon
        >>> data = vbt.FinPyData.pull(
        ...     "EURUSD",
        ...     start="14 June 2016",
        ...     end="15 June 2016",
        ...     timeframe="tick",
        ...     category="fx",
        ...     fields=["bid", "ask"],
        ...     data_source="dukascopy"
        ... )
        ```

        * Pull data (string format):

        ```pycon
        >>> data = vbt.FinPyData.pull(
        ...     "fx.dukascopy.tick.NYC.EURUSD.bid,ask",
        ...     start="14 June 2016",
        ...     end="15 June 2016",
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.finpy")

    @classmethod
    def resolve_market(
        cls,
        market: tp.Optional[MarketT] = None,
        **market_config,
    ) -> MarketT:
        """Resolve the market.

        If provided, must be of the type `findatapy.market.market.Market`.

        Args:
            market (findatapy.market.market.Market): Already initialized market.

                If None (also after resolving the setting `market`), a new one is created.
            **market_config: Keyword arguments merged over the setting `market_config`
                and passed to `findatapy.market.Market`.

                If `market_data_generator` is missing, a new
                `findatapy.market.MarketDataGenerator` is used.

        Raises:
            ValueError: If `market_config` is passed together with an already initialized market."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("findatapy")
        from findatapy.market import Market, MarketDataGenerator

        market = cls.resolve_custom_setting(market, "market")
        # Remember whether the caller passed any config before defaults are merged in
        has_market_config = len(market_config) > 0
        market_config = cls.resolve_custom_setting(market_config, "market_config", merge=True)
        if market is not None:
            if has_market_config:
                raise ValueError("Cannot apply market_config to already initialized market")
            return market
        # Only build a generator when a market is actually being constructed
        if "market_data_generator" not in market_config:
            market_config["market_data_generator"] = MarketDataGenerator()
        return Market(**market_config)

    @classmethod
    def resolve_config_manager(
        cls,
        config_manager: tp.Optional[ConfigManagerT] = None,
        **config_manager_config,
    ) -> ConfigManagerT:
        """Resolve the config manager.

        If provided, must be of the type `findatapy.util.ConfigManager`.

        Args:
            config_manager (findatapy.util.ConfigManager): Already initialized config manager.

                If None (also after resolving the setting `config_manager`), the instance
                returned by `ConfigManager().get_instance` is used.
            **config_manager_config: Keyword arguments merged over the setting
                `config_manager_config` and passed to `get_instance`.

        Raises:
            ValueError: If `config_manager_config` is passed together with an already
                initialized config manager."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("findatapy")
        from findatapy.util import ConfigManager

        config_manager = cls.resolve_custom_setting(config_manager, "config_manager")
        has_config_manager_config = len(config_manager_config) > 0
        config_manager_config = cls.resolve_custom_setting(config_manager_config, "config_manager_config", merge=True)
        if config_manager is None:
            config_manager = ConfigManager().get_instance(**config_manager_config)
        elif has_config_manager_config:
            raise ValueError("Cannot apply config_manager_config to already initialized config_manager")
        return config_manager

    @classmethod
    def list_symbols(
        cls,
        pattern: tp.Optional[str] = None,
        use_regex: bool = False,
        sort: bool = True,
        config_manager: tp.Optional[ConfigManagerT] = None,
        config_manager_config: tp.KwargsLike = None,
        category: tp.Optional[tp.MaybeList[str]] = None,
        data_source: tp.Optional[tp.MaybeList[str]] = None,
        freq: tp.Optional[tp.MaybeList[str]] = None,
        cut: tp.Optional[tp.MaybeList[str]] = None,
        tickers: tp.Optional[tp.MaybeList[str]] = None,
        dict_filter: tp.DictLike = None,
        smart_group: bool = False,
        return_fields: tp.Optional[tp.MaybeList[str]] = None,
        combine_parts: bool = True,
    ) -> tp.List[str]:
        """List all symbols.

        Passes most arguments to `findatapy.util.ConfigManager.free_form_tickers_regex_query`.

        Each returned row is turned into one or more dot-separated symbols built from the
        columns `category`, `data_source`, `freq`, `cut`, `tickers`, and `fields` (in this
        order, whichever are present). If `combine_parts` is True, comma-separated values
        within a part are expanded into all combinations.

        `return_fields` selects the columns to return: None for all of the above except
        `fields`, "all" to also include `fields`, or one or more column names.

        Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.

        Returns:
            List of unique symbols, sorted if `sort` is True."""
        if config_manager_config is None:
            config_manager_config = {}
        config_manager = cls.resolve_config_manager(config_manager=config_manager, **config_manager_config)
        if dict_filter is None:
            dict_filter = {}
        def_ret_fields = ["category", "data_source", "freq", "cut", "tickers"]
        if return_fields is None:
            ret_fields = def_ret_fields
        elif isinstance(return_fields, str):
            if return_fields.lower() == "all":
                ret_fields = def_ret_fields + ["fields"]
            else:
                ret_fields = [return_fields]
        else:
            ret_fields = return_fields
        df = config_manager.free_form_tickers_regex_query(
            category=category,
            data_source=data_source,
            freq=freq,
            cut=cut,
            tickers=tickers,
            dict_filter=dict_filter,
            smart_group=smart_group,
            ret_fields=ret_fields,
        )
        # The order of the parts defines the layout of the symbol string
        part_names = ("category", "data_source", "freq", "cut", "tickers", "fields")
        all_symbols = []
        for _, row in df.iterrows():
            parts = [row.loc[name] for name in part_names if name in row.index]
            if combine_parts:
                split_parts = [part.split(",") for part in parts]
                combinations = list(product(*split_parts))
            else:
                combinations = [parts]
            for combination in combinations:
                symbol = ".".join(combination)
                if pattern is not None and not cls.key_match(symbol, pattern, use_regex=use_regex):
                    continue
                all_symbols.append(symbol)

        # `dict.fromkeys` removes duplicates while keeping the first occurrence
        if sort:
            return sorted(dict.fromkeys(all_symbols))
        return list(dict.fromkeys(all_symbols))

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        market: tp.Optional[MarketT] = None,
        market_config: tp.KwargsLike = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        **request_kwargs,
    ) -> tp.SymbolData:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from findatapy.

        Args:
            symbol (str): Symbol.

                Also accepts the format such as "fx.bloomberg.daily.NYC.EURUSD.close".
                The fields `freq`, `cut`, `tickers`, and `fields` here are optional.
                A symbol with at least two dots is parsed as such a request string.
            market (findatapy.market.market.Market): Market.

                See `FinPyData.resolve_market`.
            market_config (dict): Client config.

                See `FinPyData.resolve_market`.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            **request_kwargs: Other keyword arguments passed to
                `findatapy.market.marketdatarequest.MarketDataRequest`.

                The special keys `md_request`, `md_request_df`, `md_request_str`, and
                `md_request_dict` (checked in this order) select how the request is built.
                The key `resample`, if present, overrides the returned frequency.

        Returns:
            None if the market returned nothing, otherwise the fetched DataFrame and
            a dict with the keys `tz` and `freq`.

        Raises:
            ValueError: If `timeframe` cannot be split into a multiplier and a unit.

        For defaults, see `custom.finpy` in `vectorbtpro._settings.data`.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("findatapy")
        from findatapy.market import MarketDataRequest

        if market_config is None:
            market_config = {}
        market = cls.resolve_market(market=market, **market_config)
        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        request_kwargs = cls.resolve_custom_setting(request_kwargs, "request_kwargs", merge=True)
        # Private copy: the request source key is removed below before forwarding the rest
        request_kwargs = dict(request_kwargs)

        split = dt.split_freq_str(timeframe)
        if split is None:
            raise ValueError(f"Invalid timeframe: '{timeframe}'")
        multiplier, unit = split
        # Translate the unit into the name passed as `freq` to findatapy;
        # unknown units (such as "tick") are forwarded unchanged
        unit_names = {
            "s": "second",
            "m": "minute",
            "h": "hourly",
            "D": "daily",
            "W": "weekly",
            "M": "monthly",
            "Q": "quarterly",
            "Y": "annually",
        }
        if unit in unit_names:
            unit = unit_names[unit]
            freq = timeframe
        else:
            freq = None
        if "resample" in request_kwargs:
            freq = request_kwargs["resample"]

        # Naive datetimes are interpreted in `tz` and converted to naive UTC
        if start is not None:
            start = dt.to_naive_datetime(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
        if end is not None:
            end = dt.to_naive_datetime(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))

        common_kwargs = dict(start_date=start, finish_date=end, freq_mult=multiplier, freq=unit)
        if "md_request" in request_kwargs:
            md_request = request_kwargs["md_request"]
        elif "md_request_df" in request_kwargs:
            # Pop the key so that it isn't passed twice (explicitly and via **request_kwargs),
            # which would raise a TypeError
            md_request_df = request_kwargs.pop("md_request_df")
            md_request = market.create_md_request_from_dataframe(
                md_request_df=md_request_df,
                **common_kwargs,
                **request_kwargs,
            )
        elif "md_request_str" in request_kwargs:
            md_request_str = request_kwargs.pop("md_request_str")
            md_request = market.create_md_request_from_str(
                md_request_str=md_request_str,
                **common_kwargs,
                **request_kwargs,
            )
        elif "md_request_dict" in request_kwargs:
            md_request_dict = request_kwargs.pop("md_request_dict")
            md_request = market.create_md_request_from_dict(
                md_request_dict=md_request_dict,
                **common_kwargs,
                **request_kwargs,
            )
        elif symbol.count(".") >= 2:
            md_request = market.create_md_request_from_str(
                md_request_str=symbol,
                **common_kwargs,
                **request_kwargs,
            )
        else:
            md_request = MarketDataRequest(
                tickers=symbol,
                **common_kwargs,
                **request_kwargs,
            )

        df = market.fetch_market(md_request=md_request)
        if df is None:
            return None
        if isinstance(md_request.tickers, str):
            ticker = md_request.tickers
        elif len(md_request.tickers) == 1:
            ticker = md_request.tickers[0]
        else:
            ticker = None
        if ticker is not None:
            # With a single ticker, drop the "<ticker>." prefix from column names
            df.columns = df.columns.map(lambda x: x.replace(ticker + ".", ""))
        if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
            df = df.tz_localize("utc")
        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Re-fetches `symbol` starting at its last stored index, using the stored
        fetch arguments overridden by `kwargs`."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)
class GBMOHLCData(SyntheticData):
    """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_1d_nb`
    and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.gbm_ohlc")

    @classmethod
    def generate_symbol(
        cls,
        symbol: tp.Symbol,
        index: tp.Index,
        n_ticks: tp.Optional[tp.ArrayLike] = None,
        start_value: tp.Optional[float] = None,
        mean: tp.Optional[float] = None,
        std: tp.Optional[float] = None,
        dt: tp.Optional[float] = None,
        seed: tp.Optional[int] = None,
        jitted: tp.JittedOption = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.SymbolData:
        """Generate a symbol.

        Args:
            symbol (hashable): Symbol.
            index (pd.Index): Pandas index.
            n_ticks (int or array_like): Number of ticks per bar.

                Flexible argument. Can be a template with a context containing `symbol` and `index`.
            start_value (float): Value at time 0.

                Does not appear as the first value in the output data.
            mean (float): Drift, or mean of the percentage change.
            std (float): Standard deviation of the percentage change.
            dt (float): Time change (one period of time).
            seed (int): Seed to make output deterministic.
            jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
            template_context (dict): Context used to substitute templates.

        Returns:
            DataFrame indexed by `index` with the columns "Open", "High", "Low", and "Close".

        For defaults, see `custom.gbm_ohlc` in `vectorbtpro._settings.data`.

        !!! note
            When setting a seed, remember to pass a seed per symbol using
            `vectorbtpro.data.base.symbol_dict`.
        """
        # Resolve the tick counts: setting -> template -> one value per bar
        ticks_per_bar = cls.resolve_custom_setting(n_ticks, "n_ticks")
        context = merge_dicts(dict(symbol=symbol, index=index), template_context)
        ticks_per_bar = substitute_templates(ticks_per_bar, context, eval_id="n_ticks")
        ticks_per_bar = broadcast_array_to(ticks_per_bar, len(index))

        start_value = cls.resolve_custom_setting(start_value, "start_value")
        mean = cls.resolve_custom_setting(mean, "mean")
        std = cls.resolve_custom_setting(std, "std")
        dt = cls.resolve_custom_setting(dt, "dt")
        seed = cls.resolve_custom_setting(seed, "seed")
        if seed is not None:
            set_seed(seed)

        # One long tick series covering all bars
        gbm_func = jit_reg.resolve_option(nb.generate_gbm_data_1d_nb, jitted)
        tick_values = gbm_func(
            np.sum(ticks_per_bar),
            start_value=start_value,
            mean=mean,
            std=std,
            dt=dt,
        )
        # Collapse every `ticks_per_bar[i]` ticks into one OHLC bar
        ohlc_func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
        bars = ohlc_func(tick_values, ticks_per_bar)
        return pd.DataFrame(bars, index=index, columns=["Open", "High", "Low", "Close"])

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Re-generates `symbol` starting at its last stored index. The stored `start_value` is
        discarded in favor of the last stored "Open" value, and the stored seed is reset to None.
        Other stored fetch arguments are overridden by `kwargs`."""
        stored_kwargs = self.select_fetch_kwargs(symbol)
        stored_kwargs["start"] = self.select_last_index(symbol)
        stored_kwargs.pop("start_value", None)
        last_open = self.data[symbol]["Open"].iloc[-1]
        stored_kwargs["seed"] = None
        merged_kwargs = merge_dicts(stored_kwargs, kwargs)
        return self.fetch_symbol(symbol, start_value=last_open, **merged_kwargs)
class GBMData(SyntheticData):
    """`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_nb`."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.gbm")

    @classmethod
    def generate_key(
        cls,
        key: tp.Key,
        index: tp.Index,
        columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
        start_value: tp.Optional[float] = None,
        mean: tp.Optional[float] = None,
        std: tp.Optional[float] = None,
        dt: tp.Optional[float] = None,
        seed: tp.Optional[int] = None,
        jitted: tp.JittedOption = None,
        **kwargs,
    ) -> tp.KeyData:
        """Generate a feature or symbol.

        Args:
            key (hashable): Feature or symbol.
            index (pd.Index): Pandas index.
            columns (hashable or index_like): Column names.

                Provide a single value (hashable) to make a Series.
            start_value (float): Value at time 0.

                Does not appear as the first value in the output data.
            mean (float): Drift, or mean of the percentage change.
            std (float): Standard deviation of the percentage change.
            dt (float): Time change (one period of time).
            seed (int): Seed to make output deterministic.
            jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.

        Returns:
            Series named after `columns` if `columns` is hashable, otherwise
            DataFrame with the columns `columns`; both indexed by `index`.

        For defaults, see `custom.gbm` in `vectorbtpro._settings.data`.

        !!! note
            When setting a seed, remember to pass a seed per feature/symbol using
            `vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally
            `vectorbtpro.data.base.key_dict`.
        """
        # A single hashable column means the caller wants a Series back
        make_series = checks.is_hashable(columns)
        if make_series:
            columns = [columns]
        if not isinstance(columns, pd.Index):
            columns = pd.Index(columns)

        start_value = cls.resolve_custom_setting(start_value, "start_value")
        mean = cls.resolve_custom_setting(mean, "mean")
        std = cls.resolve_custom_setting(std, "std")
        dt = cls.resolve_custom_setting(dt, "dt")
        seed = cls.resolve_custom_setting(seed, "seed")
        if seed is not None:
            set_seed(seed)

        generate = jit_reg.resolve_option(nb.generate_gbm_data_nb, jitted)
        target_shape = (len(index), len(columns))
        values = generate(
            target_shape,
            start_value=to_1d_array(start_value),
            mean=to_1d_array(mean),
            std=to_1d_array(std),
            dt=to_1d_array(dt),
        )

        if make_series:
            return pd.Series(values[:, 0], index=index, name=columns[0])
        return pd.DataFrame(values, index=index, columns=columns)

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Update data of a feature or symbol.

        Re-generates `key` starting at its last stored index. The stored `start_value` is
        discarded in favor of the second-to-last stored value, and the stored seed is reset
        to None. Other stored fetch arguments are overridden by `kwargs`."""
        stored_kwargs = self.select_fetch_kwargs(key)
        stored_kwargs["start"] = self.select_last_index(key)
        stored_kwargs.pop("start_value", None)
        # NOTE(review): presumably the second-to-last value is used because the last row
        # is generated anew (generation restarts at the last index) - confirm
        previous_value = self.data[key].iloc[-2]
        stored_kwargs["seed"] = None
        merged_kwargs = merge_dicts(stored_kwargs, kwargs)
        fetcher = self.fetch_feature if key_is_feature else self.fetch_symbol
        return fetcher(key, start_value=previous_value, **merged_kwargs)
class HDFPathNotFoundError(Exception):
    """Gets raised if the path to an HDF file could not be found."""


class HDFKeyNotFoundError(Exception):
    """Gets raised if the key to an HDF object could not be found."""


HDFDataT = tp.TypeVar("HDFDataT", bound="HDFData")


class HDFData(FileData):
    """Data class for fetching HDF data using PyTables."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.hdf")

    @classmethod
    def is_hdf_file(cls, path: tp.PathLike) -> bool:
        """Return whether the path is an existing file with an `.hdf`, `.hdf5`, or `.h5` suffix."""
        hdf_path = path if isinstance(path, Path) else Path(path)
        if not (hdf_path.exists() and hdf_path.is_file()):
            return False
        suffixes = hdf_path.suffixes
        return any(suffix in suffixes for suffix in (".hdf", ".hdf5", ".h5"))

    @classmethod
    def is_file_match(cls, path: tp.PathLike) -> bool:
        """Return whether the path is an HDF file. Uses `HDFData.is_hdf_file`."""
        return cls.is_hdf_file(path)

    @classmethod
    def split_hdf_path(
        cls,
        path: tp.PathLike,
        key: tp.Optional[str] = None,
        _full_path: tp.Optional[Path] = None,
    ) -> tp.Tuple[Path, tp.Optional[str]]:
        """Split the path to an HDF object into the path to the file and the key.

        Walks up the path one component at a time until an existing path is found; the
        components stripped on the way are joined into the key. Raises
        `HDFPathNotFoundError` if the first existing path is a directory."""
        path = Path(path)
        if _full_path is None:
            # Remember the original path for the error message across recursive calls
            _full_path = path
        if not path.exists():
            # Move the last component from the path into the key and retry with the parent
            parent_key = path.name if key is None else str(Path(path.name) / key)
            return cls.split_hdf_path(path.parent, parent_key, _full_path=_full_path)
        if path.is_dir():
            raise HDFPathNotFoundError(f"No HDF files could be matched with {_full_path}")
        return path, key

    @classmethod
    def match_path(
        cls,
        path: tp.PathLike,
        match_regex: tp.Optional[str] = None,
        sort_paths: bool = True,
        recursive: bool = True,
        **kwargs,
    ) -> tp.List[Path]:
        """Override `FileData.match_path` to return a list of HDF paths
        (path to file + key) matching a path.

        Args:
            path (path_like): Path to a directory, an HDF file, an HDF object inside a file,
                or a glob pattern.
            match_regex (str): If given, only paths whose string matches this regex
                (via `re.match`) are kept.
            sort_paths (bool): Whether to sort the resulting paths.
            recursive (bool): Passed to `glob.glob` when the path is treated as a glob pattern.
            **kwargs: Forwarded to nested `match_path` calls."""
        path = Path(path)
        if path.exists():
            if path.is_dir() and not cls.is_dir_match(path):
                # Plain directory: descend into each matching child
                children = [
                    child
                    for child in path.iterdir()
                    if (child.is_dir() and cls.is_dir_match(child)) or (child.is_file() and cls.is_file_match(child))
                ]
                key_paths = [
                    found for child in children for found in cls.match_path(child, sort_paths=False, **kwargs)
                ]
            else:
                # Existing file: every key in the store becomes a path (leading "/" stripped)
                with pd.HDFStore(str(path), mode="r") as store:
                    store_keys = [k[1:] for k in store.keys()]
                key_paths = [path / k for k in store_keys]
        else:
            try:
                file_path, key = cls.split_hdf_path(path)
                with pd.HDFStore(str(file_path), mode="r") as store:
                    store_keys = [k[1:] for k in store.keys()]
                if key is None:
                    key_paths = [file_path / k for k in store_keys]
                elif key in store_keys:
                    key_paths = [file_path / key]
                else:
                    # Fall back to prefix matching and glob-style matching of the key
                    matched_keys = [
                        k
                        for k in store_keys
                        if k.startswith(key) or PurePath("/" + str(k)).match("/" + str(key))
                    ]
                    if len(matched_keys) == 0:
                        raise HDFKeyNotFoundError(f"No HDF keys could be matched with {key}")
                    key_paths = [file_path / k for k in matched_keys]
            except HDFPathNotFoundError:
                # No existing file on the way up: treat the path as a glob pattern
                sub_paths = [Path(p) for p in glob(str(path), recursive=recursive)]
                if len(sub_paths) == 0 and re.match(r".+\..+", str(path)):
                    # Split at the first component that looks like "name.ext": everything up to
                    # and including it is globbed as the file part, the rest is appended as the key
                    base_path = None
                    base_ended = False
                    key_path = None
                    for raw_part in path.parts:
                        part = Path(raw_part)
                        if base_ended:
                            key_path = part if key_path is None else key_path / part
                        else:
                            if re.match(r".+\..+", str(part)):
                                base_ended = True
                            base_path = part if base_path is None else base_path / part
                    sub_paths = [Path(p) for p in glob(str(base_path), recursive=recursive)]
                    if key_path is not None:
                        sub_paths = [p / key_path for p in sub_paths]
                key_paths = [
                    found for sub_path in sub_paths for found in cls.match_path(sub_path, sort_paths=False, **kwargs)
                ]
        if match_regex is not None:
            key_paths = [p for p in key_paths if re.match(match_regex, str(p))]
        if sort_paths:
            key_paths = sorted(key_paths)
        return key_paths

    @classmethod
    def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
        """Return the last component of the path as the key."""
        return Path(path).name

    @classmethod
    def resolve_keys_meta(
        cls,
        keys: tp.Union[None, dict, tp.MaybeKeys] = None,
        keys_are_features: tp.Optional[bool] = None,
        features: tp.Union[None, dict, tp.MaybeFeatures] = None,
        symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
        paths: tp.Any = None,
    ) -> tp.Kwargs:
        """Resolve keys with `FileData.resolve_keys_meta`.

        If neither keys nor `paths` are provided, uses `cls.list_paths()` as keys."""
        keys_meta = FileData.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
        )
        no_keys_given = keys_meta["keys"] is None
        if no_keys_given and paths is None:
            keys_meta["keys"] = cls.list_paths()
        return keys_meta

    @classmethod
    def fetch_key(
        cls,
        key: tp.Key,
        path: tp.Any = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        tz: tp.TimezoneLike = None,
        start_row: tp.Optional[int] = None,
        end_row: tp.Optional[int] = None,
        chunk_func: tp.Optional[tp.Callable] = None,
        **read_kwargs,
    ) -> tp.KeyData:
        """Fetch the HDF object of a feature or symbol.

        Args:
            key (hashable): Feature or symbol.
            path (str): Path.

                Will be resolved with `HDFData.split_hdf_path`.

                If `path` is None, uses `key` as the path to the HDF file.
            start (any): Start datetime (inclusive).

                Will extract the object's index and compare the index to the date.
                Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.

                !!! note
                    Can only be used if the object was saved in the table format!
            end (any): End datetime (exclusive).

                Will extract the object's index and compare the index to the date.
                Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.

                !!! note
                    Can only be used if the object was saved in the table format!
            tz (any): Target timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            start_row (int): Start row (inclusive).

                Will use it when querying index as well.
            end_row (int): End row (exclusive).

                Will use it when querying index as well.
            chunk_func (callable): Function to select and concatenate chunks from `TableIterator`.

                Gets called only if `iterator` or `chunksize` are set.
            **read_kwargs: Other keyword arguments passed to `pd.read_hdf`.

        Returns:
            None if `start`/`end` select no rows; otherwise a tuple of the object and
            a dict with `last_row` and `tz`.

        See https://pandas.pydata.org/docs/reference/api/pandas.read_hdf.html for other arguments.

        For defaults, see `custom.hdf` in `vectorbtpro._settings.data`."""
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("tables")

        from pandas.io.pytables import TableIterator

        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        tz = cls.resolve_custom_setting(tz, "tz")
        start_row = cls.resolve_custom_setting(start_row, "start_row")
        if start_row is None:
            start_row = 0
        end_row = cls.resolve_custom_setting(end_row, "end_row")
        read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)

        if path is None:
            path = key
        file_path, file_key = cls.split_hdf_path(Path(path))
        if file_key is not None:
            key = file_key

        if start is not None or end is not None:
            # Only those read arguments that `pd.HDFStore` accepts are used to open the store
            store_arg_names = get_func_arg_names(pd.HDFStore.__init__)
            store_kwargs = {name: value for name, value in read_kwargs.items() if name in store_arg_names}
            with pd.HDFStore(str(file_path), mode="r", **store_kwargs) as store:
                index = store.select_column(key, "index", start=start_row, stop=end_row)
            if not isinstance(index, pd.Index):
                index = pd.Index(index)
            if not isinstance(index, pd.DatetimeIndex):
                raise TypeError("Cannot filter index that is not DatetimeIndex")
            if tz is None:
                tz = index.tz
            # Bring the bounds into the same timezone-awareness as the stored index
            if index.tz is not None:
                if start is not None:
                    start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=index.tz)
                if end is not None:
                    end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=index.tz)
            else:
                if start is not None:
                    start = dt.to_naive_timestamp(start, tz=tz)
                if end is not None:
                    end = dt.to_naive_timestamp(end, tz=tz)
            mask = True
            if start is not None:
                mask &= index >= start
            if end is not None:
                mask &= index < end
            selected = np.flatnonzero(mask)
            if len(selected) == 0:
                return None
            # Translate positions within the queried index back into absolute row numbers
            first_pos, last_pos = selected[0], selected[-1]
            start_row += first_pos
            end_row = start_row + last_pos - first_pos + 1

        obj = pd.read_hdf(file_path, key=key, start=start_row, stop=end_row, **read_kwargs)
        if isinstance(obj, TableIterator):
            obj = pd.concat(list(obj), axis=0) if chunk_func is None else chunk_func(obj)
        if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
            tz = obj.index.tz
        last_row = start_row + len(obj.index) - 1
        return obj, dict(last_row=last_row, tz=tz)

    @classmethod
    def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Fetch the HDF object of a feature.

        Uses `HDFData.fetch_key`."""
        return cls.fetch_key(feature, **kwargs)

    @classmethod
    def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Load the HDF object for a symbol.

        Uses `HDFData.fetch_key`."""
        return cls.fetch_key(symbol, **kwargs)

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Update data of a feature or symbol.

        Re-fetches with the stored fetch arguments, starting from the `last_row`
        returned by the previous fetch. Keyword arguments override the stored ones."""
        fetch_kwargs = self.select_fetch_kwargs(key)
        fetch_kwargs["start_row"] = self.select_returned_kwargs(key)["last_row"]
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        fetch = self.fetch_feature if key_is_feature else self.fetch_symbol
        return fetch(key, **kwargs)

    def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Update data of a feature.

        Uses `HDFData.update_key` with `key_is_feature=True`."""
        return self.update_key(feature, key_is_feature=True, **kwargs)

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Uses `HDFData.update_key` with `key_is_feature=False`."""
        return self.update_key(symbol, key_is_feature=False, **kwargs)
class LocalData(CustomData):
    """Data class for fetching local data.

    Defines no behavior of its own on top of `CustomData`; it only points
    the custom settings lookup to `data.custom.local`."""

    # Path under which the custom settings of this class are looked up
    _settings_path: tp.SettingsPath = dict(custom="data.custom.local")
class NDLData(RemoteData):
    """Data class for fetching from Nasdaq Data Link.

    See https://github.com/Nasdaq/data-link-python for API.

    See `NDLData.fetch_symbol` for arguments.

    Usage:
        * Set up the API key globally (optional):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.NDLData.set_custom_settings(
        ...     api_key="YOUR_KEY"
        ... )
        ```

        * Pull a dataset:

        ```pycon
        >>> data = vbt.NDLData.pull(
        ...     "FRED/GDP",
        ...     start="2001-12-31",
        ...     end="2005-12-31"
        ... )
        ```

        * Pull a datatable:

        ```pycon
        >>> data = vbt.NDLData.pull(
        ...     "MER/F1",
        ...     data_format="datatable",
        ...     compnumber="39102",
        ...     paginate=True
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.ndl")

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        api_key: tp.Optional[str] = None,
        data_format: tp.Optional[str] = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        tz: tp.TimezoneLike = None,
        column_indices: tp.Optional[tp.MaybeIterable[int]] = None,
        **params,
    ) -> tp.SymbolData:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Nasdaq Data Link.

        Args:
            symbol (str): Symbol.
            api_key (str): API key.
            data_format (str): Data format.

                "dataset" (case-insensitive) uses `nasdaqdatalink.get`,
                anything else uses `nasdaqdatalink.get_table`.
            start (any): Retrieve data rows on and after the specified start date.

                Sent as `start_date` unless `params` already contains one.
                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): Retrieve data rows up to the specified end date.

                Sent as `end_date` unless `params` already contains one. After fetching,
                rows of a datetime index at or after `end` are dropped.
                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            column_indices (int or iterable): Request one or more specific columns.

                Column 0 is the date column and is always returned. Data begins at column 1.
            **params: Keyword arguments sent as field/value params to Nasdaq Data Link with no interference.

        Returns:
            Tuple of the fetched DataFrame and a dict with `tz`.

        For defaults, see `custom.ndl` in `vectorbtpro._settings.data`.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("nasdaqdatalink")

        import nasdaqdatalink

        api_key = cls.resolve_custom_setting(api_key, "api_key")
        data_format = cls.resolve_custom_setting(data_format, "data_format")
        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        tz = cls.resolve_custom_setting(tz, "tz")
        column_indices = cls.resolve_custom_setting(column_indices, "column_indices")
        # Specific columns are requested using the "<symbol>.<column index>" notation
        if column_indices is None:
            dataset = symbol
        elif isinstance(column_indices, int):
            dataset = symbol + "." + str(column_indices)
        else:
            dataset = [symbol + "." + str(index) for index in column_indices]
        params = cls.resolve_custom_setting(params, "params", merge=True)

        # Establish the timestamps; explicitly provided params take precedence
        if start is not None:
            start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
            params.setdefault("start_date", pd.Timestamp(start).isoformat())
        if end is not None:
            end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
            params.setdefault("end_date", pd.Timestamp(end).isoformat())

        # Collect and format the data
        if data_format.lower() == "dataset":
            df = nasdaqdatalink.get(
                dataset,
                api_key=api_key,
                **params,
            )
        else:
            df = nasdaqdatalink.get_table(
                dataset,
                api_key=api_key,
                **params,
            )

        # Strip the "<symbol> - " prefix from column names and rename "Last" to "Close"
        new_columns = []
        for c in df.columns:
            new_c = c
            if isinstance(symbol, str):
                new_c = new_c.replace(symbol + " - ", "")
            if new_c == "Last":
                new_c = "Close"
            new_columns.append(new_c)
        df = df.rename(columns=dict(zip(df.columns, new_columns)))

        if df.index.name == "None":
            df.index.name = None
        if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
            df = df.tz_localize("utc")
        if isinstance(df.index, pd.DatetimeIndex) and not df.empty:
            if start is not None:
                start = dt.to_timestamp(start, tz=df.index.tz)
                if df.index[0] < start:
                    df = df[df.index >= start]
            # The start filter may have removed every row, in which case
            # `df.index[-1]` would raise an IndexError
            if end is not None and len(df.index) > 0:
                end = dt.to_timestamp(end, tz=df.index.tz)
                if df.index[-1] >= end:
                    df = df[df.index < end]
        return df, dict(tz=tz)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Re-fetches with the stored fetch arguments, using the last stored index as `start`.
        Keyword arguments override the stored ones."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        return self.fetch_symbol(symbol, **kwargs)
class ParquetData(FileData):
    """Data class for fetching Parquet data using PyArrow or FastParquet."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.parquet")

    @classmethod
    def is_parquet_file(cls, path: tp.PathLike) -> bool:
        """Return whether the path is an existing file with a `.parquet` suffix."""
        if not isinstance(path, Path):
            path = Path(path)
        return path.exists() and path.is_file() and ".parquet" in path.suffixes

    @classmethod
    def is_parquet_group_dir(cls, path: tp.PathLike) -> bool:
        """Return whether the path is a directory that is a group of Parquet partitions.

        The directory must be named `<column>=<value>` and contain at least one
        Parquet file or nested group directory.

        !!! note
            Assumes the Hive partitioning scheme."""
        if not isinstance(path, Path):
            path = Path(path)
        if not (path.exists() and path.is_dir()):
            return False
        if not re.match(r"^(.+)=(.+)", path.name):
            return False
        return any(cls.is_parquet_group_dir(p) or cls.is_parquet_file(p) for p in path.iterdir())

    @classmethod
    def is_parquet_dir(cls, path: tp.PathLike) -> bool:
        """Return whether the path is a directory that is a group itself or
        contains groups of Parquet partitions."""
        if cls.is_parquet_group_dir(path):
            return True
        if not isinstance(path, Path):
            path = Path(path)
        if path.exists() and path.is_dir():
            return any(cls.is_parquet_group_dir(p) for p in path.iterdir())
        return False

    @classmethod
    def is_dir_match(cls, path: tp.PathLike) -> bool:
        """Return whether the path is a Parquet directory. Uses `ParquetData.is_parquet_dir`."""
        return cls.is_parquet_dir(path)

    @classmethod
    def is_file_match(cls, path: tp.PathLike) -> bool:
        """Return whether the path is a Parquet file. Uses `ParquetData.is_parquet_file`."""
        return cls.is_parquet_file(path)

    @classmethod
    def list_partition_cols(cls, path: tp.PathLike) -> tp.List[str]:
        """List partitioning columns under a path.

        Descends one level at a time, following the first group directory found at each
        level, and collects the column name (the part before "=") of each such directory.

        !!! note
            Assumes the Hive partitioning scheme."""
        if not isinstance(path, Path):
            path = Path(path)
        partition_cols = []
        while True:
            # All groups on one level share the partitioning column, so the first one suffices
            next_level = next((p for p in path.iterdir() if cls.is_parquet_group_dir(p)), None)
            if next_level is None:
                break
            partition_cols.append(next_level.name.split("=")[0])
            path = next_level
        return partition_cols

    @classmethod
    def is_default_partition_col(cls, level: str) -> bool:
        """Return whether a partitioning column is a default partitioning column.

        A default partitioning column is named exactly "group" or "group_{index}"."""
        # Match the entire name: the previous pattern was anchored only at the start and
        # thus also accepted names such as "group-x" or "group_1abc"
        return re.fullmatch(r"group(_\d+)?", level) is not None

    @classmethod
    def resolve_keys_meta(
        cls,
        keys: tp.Union[None, dict, tp.MaybeKeys] = None,
        keys_are_features: tp.Optional[bool] = None,
        features: tp.Union[None, dict, tp.MaybeFeatures] = None,
        symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
        paths: tp.Any = None,
    ) -> tp.Kwargs:
        """Resolve keys with `FileData.resolve_keys_meta`.

        If neither keys nor `paths` are provided, uses `cls.list_paths()` as keys."""
        keys_meta = FileData.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
        )
        if keys_meta["keys"] is None and paths is None:
            keys_meta["keys"] = cls.list_paths()
        return keys_meta

    @classmethod
    def fetch_key(
        cls,
        key: tp.Key,
        path: tp.Any = None,
        tz: tp.TimezoneLike = None,
        squeeze: tp.Optional[bool] = None,
        keep_partition_cols: tp.Optional[bool] = None,
        engine: tp.Optional[str] = None,
        **read_kwargs,
    ) -> tp.KeyData:
        """Fetch the Parquet file of a feature or symbol.

        Args:
            key (hashable): Feature or symbol.
            path (str): Path.

                If `path` is None, uses `key` as the path to the Parquet file.
            tz (any): Target timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`.
            squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
            keep_partition_cols (bool): Whether to return partitioning columns (if any).

                If None, will remove any partitioning column that is "group" or "group_{index}".
                If False, will remove all partitioning columns.

                Retrieves the list of partitioning columns with `ParquetData.list_partition_cols`.
            engine (str): See `pd.read_parquet`.

                Must be "pyarrow", "fastparquet", or "auto".
            **read_kwargs: Other keyword arguments passed to `pd.read_parquet`.

        Returns:
            Tuple of the fetched object and a dict with `tz`.

        Raises:
            ValueError: If `engine` is not one of the supported engines.

        See https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html for other arguments.

        For defaults, see `custom.parquet` in `vectorbtpro._settings.data`."""
        from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any

        tz = cls.resolve_custom_setting(tz, "tz")
        squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
        keep_partition_cols = cls.resolve_custom_setting(keep_partition_cols, "keep_partition_cols")
        engine = cls.resolve_custom_setting(engine, "engine")
        read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)

        if engine == "pyarrow":
            assert_can_import("pyarrow")
        elif engine == "fastparquet":
            assert_can_import("fastparquet")
        elif engine == "auto":
            assert_can_import_any("pyarrow", "fastparquet")
        else:
            raise ValueError(f"Invalid engine: '{engine}'")

        if path is None:
            path = key
        obj = pd.read_parquet(path, engine=engine, **read_kwargs)

        if keep_partition_cols in (None, False):
            if cls.is_parquet_dir(path):
                partition_cols = cls.list_partition_cols(path)
                # False drops every partitioning column, None only the default ones
                drop_columns = [
                    col
                    for col in obj.columns
                    if col in partition_cols
                    and (keep_partition_cols is False or cls.is_default_partition_col(col))
                ]
                obj = obj.drop(drop_columns, axis=1)
        if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
            tz = obj.index.tz
        if isinstance(obj, pd.DataFrame) and squeeze:
            obj = obj.squeeze("columns")
        if isinstance(obj, pd.Series) and obj.name == "0":
            obj.name = None
        return obj, dict(tz=tz)

    @classmethod
    def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Fetch the Parquet file of a feature.

        Uses `ParquetData.fetch_key`."""
        return cls.fetch_key(feature, **kwargs)

    @classmethod
    def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Fetch the Parquet file of a symbol.

        Uses `ParquetData.fetch_key`."""
        return cls.fetch_key(symbol, **kwargs)

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Update data of a feature or symbol.

        Re-fetches with the stored fetch arguments; keyword arguments override the stored ones."""
        fetch_kwargs = self.select_fetch_kwargs(key)
        kwargs = merge_dicts(fetch_kwargs, kwargs)
        if key_is_feature:
            return self.fetch_feature(key, **kwargs)
        return self.fetch_symbol(key, **kwargs)

    def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Update data of a feature.

        Uses `ParquetData.update_key` with `key_is_feature=True`."""
        return self.update_key(feature, key_is_feature=True, **kwargs)

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Uses `ParquetData.update_key` with `key_is_feature=False`."""
        return self.update_key(symbol, key_is_feature=False, **kwargs)
# =================================================================================== """Module with `PolygonData`.""" import time import traceback from functools import wraps, partial import pandas as pd import requests from vectorbtpro import _typing as tp from vectorbtpro.data.custom.remote import RemoteData from vectorbtpro.utils import datetime_ as dt from vectorbtpro.utils.config import merge_dicts from vectorbtpro.utils.pbar import ProgressBar from vectorbtpro.utils.warnings_ import warn try: if not tp.TYPE_CHECKING: raise ImportError from polygon import RESTClient as PolygonClientT except ImportError: PolygonClientT = "PolygonClient" __all__ = [ "PolygonData", ] PolygonDataT = tp.TypeVar("PolygonDataT", bound="PolygonData") class PolygonData(RemoteData): """Data class for fetching from Polygon. See https://github.com/polygon-io/client-python for API. See `PolygonData.fetch_symbol` for arguments. Usage: * Set up the API key globally: ```pycon >>> from vectorbtpro import * >>> vbt.PolygonData.set_custom_settings( ... client_config=dict( ... api_key="YOUR_KEY" ... ) ... ) ``` * Pull stock data: ```pycon >>> data = vbt.PolygonData.pull( ... "AAPL", ... start="2021-01-01", ... end="2022-01-01", ... timeframe="1 day" ... ) ``` * Pull crypto data: ```pycon >>> data = vbt.PolygonData.pull( ... "X:BTCUSD", ... start="2021-01-01", ... end="2022-01-01", ... timeframe="1 day" ... ) ``` """ _settings_path: tp.SettingsPath = dict(custom="data.custom.polygon") @classmethod def list_symbols( cls, pattern: tp.Optional[str] = None, use_regex: bool = False, sort: bool = True, client: tp.Optional[PolygonClientT] = None, client_config: tp.DictLike = None, **list_tickers_kwargs, ) -> tp.List[str]: """List all symbols. Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`. 
For supported keyword arguments, see `polygon.RESTClient.list_tickers`.""" if client_config is None: client_config = {} client = cls.resolve_client(client=client, **client_config) all_symbols = [] for ticker in client.list_tickers(**list_tickers_kwargs): symbol = ticker.ticker if pattern is not None: if not cls.key_match(symbol, pattern, use_regex=use_regex): continue all_symbols.append(symbol) if sort: return sorted(dict.fromkeys(all_symbols)) return list(dict.fromkeys(all_symbols)) @classmethod def resolve_client(cls, client: tp.Optional[PolygonClientT] = None, **client_config) -> PolygonClientT: """Resolve the client. If provided, must be of the type `polygon.rest.RESTClient`. Otherwise, will be created using `client_config`.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("polygon") from polygon import RESTClient client = cls.resolve_custom_setting(client, "client") if client_config is None: client_config = {} has_client_config = len(client_config) > 0 client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True) if client is None: client = RESTClient(**client_config) elif has_client_config: raise ValueError("Cannot apply client_config to already initialized client") return client @classmethod def fetch_symbol( cls, symbol: str, client: tp.Optional[PolygonClientT] = None, client_config: tp.DictLike = None, start: tp.Optional[tp.DatetimeLike] = None, end: tp.Optional[tp.DatetimeLike] = None, timeframe: tp.Optional[str] = None, tz: tp.TimezoneLike = None, adjusted: tp.Optional[bool] = None, limit: tp.Optional[int] = None, params: tp.KwargsLike = None, delay: tp.Optional[float] = None, retries: tp.Optional[int] = None, show_progress: tp.Optional[bool] = None, pbar_kwargs: tp.KwargsLike = None, silence_warnings: tp.Optional[bool] = None, ) -> tp.SymbolData: """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Polygon. Args: symbol (str): Symbol. 
Supports the following APIs: * Stocks and equities * Currencies - symbol must have the prefix `C:` * Crypto - symbol must have the prefix `X:` client (polygon.rest.RESTClient): Client. See `PolygonData.resolve_client`. client_config (dict): Client config. See `PolygonData.resolve_client`. start (any): The start of the aggregate time window. See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. end (any): The end of the aggregate time window. See `vectorbtpro.utils.datetime_.to_tzaware_datetime`. timeframe (str): Timeframe. Allows human-readable strings such as "15 minutes". tz (any): Timezone. See `vectorbtpro.utils.datetime_.to_timezone`. adjusted (str): Whether the results are adjusted for splits. By default, results are adjusted. Set this to False to get results that are NOT adjusted for splits. limit (int): Limits the number of base aggregates queried to create the aggregate results. Max 50000 and Default 5000. params (dict): Any additional query params. delay (float): Time to sleep after each request (in seconds). retries (int): The number of retries on failure to fetch data. show_progress (bool): Whether to show the progress bar. pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`. silence_warnings (bool): Whether to silence all warnings. For defaults, see `custom.polygon` in `vectorbtpro._settings.data`. !!! note If you're using a free plan that has an API rate limit of several requests per minute, make sure to set `delay` to a higher number, such as 12 (which makes 5 requests per minute). 
""" if client_config is None: client_config = {} client = cls.resolve_client(client=client, **client_config) start = cls.resolve_custom_setting(start, "start") end = cls.resolve_custom_setting(end, "end") timeframe = cls.resolve_custom_setting(timeframe, "timeframe") tz = cls.resolve_custom_setting(tz, "tz") adjusted = cls.resolve_custom_setting(adjusted, "adjusted") limit = cls.resolve_custom_setting(limit, "limit") params = cls.resolve_custom_setting(params, "params", merge=True) delay = cls.resolve_custom_setting(delay, "delay") retries = cls.resolve_custom_setting(retries, "retries") show_progress = cls.resolve_custom_setting(show_progress, "show_progress") pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True) if "bar_id" not in pbar_kwargs: pbar_kwargs["bar_id"] = "polygon" silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings") # Resolve the timeframe if not isinstance(timeframe, str): raise ValueError(f"Invalid timeframe: '{timeframe}'") split = dt.split_freq_str(timeframe) if split is None: raise ValueError(f"Invalid timeframe: '{timeframe}'") multiplier, unit = split if unit == "m": unit = "minute" elif unit == "h": unit = "hour" elif unit == "D": unit = "day" elif unit == "W": unit = "week" elif unit == "M": unit = "month" elif unit == "Q": unit = "quarter" elif unit == "Y": unit = "year" # Establish the timestamps if start is not None: start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")) else: start_ts = None if end is not None: end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")) else: end_ts = None prev_end_ts = None def _retry(method): @wraps(method) def retry_method(*args, **kwargs): for i in range(retries): try: return method(*args, **kwargs) except requests.exceptions.HTTPError as e: if isinstance(e, requests.exceptions.HTTPError) and e.response.status_code == 429: if not silence_warnings: warn(traceback.format_exc()) # Polygon.io API rate 
limit is per minute warn("Waiting 1 minute...") time.sleep(60) else: raise e except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e: if i == retries - 1: raise e if not silence_warnings: warn(traceback.format_exc()) if delay is not None: time.sleep(delay) return retry_method def _postprocess(agg): return dict( o=agg.open, h=agg.high, l=agg.low, c=agg.close, v=agg.volume, vw=agg.vwap, t=agg.timestamp, n=agg.transactions, ) @_retry def _fetch(_start_ts, _limit): return list( map( _postprocess, client.get_aggs( ticker=symbol, multiplier=multiplier, timespan=unit, from_=_start_ts, to=end_ts, adjusted=adjusted, sort="asc", limit=_limit, params=params, raw=False, ), ) ) def _ts_to_str(ts: tp.Optional[int]) -> str: if ts is None: return "?" return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe) def _filter_func(d: tp.Dict, _prev_end_ts: tp.Optional[int] = None) -> bool: if start_ts is not None: if d["t"] < start_ts: return False if _prev_end_ts is not None: if d["t"] <= _prev_end_ts: return False if end_ts is not None: if d["t"] >= end_ts: return False return True # Iteratively collect the data data = [] try: with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar: pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts))) while True: # Fetch the klines for the next timeframe next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit) next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data)) # Update the timestamps and the progress bar if not len(next_data): break data += next_data if start_ts is None: start_ts = next_data[0]["t"] pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1]["t"]))) pbar.update() prev_end_ts = next_data[-1]["t"] if end_ts is not None and prev_end_ts >= end_ts: break if delay is not None: time.sleep(delay) # be kind to api except Exception as e: if not silence_warnings: 
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
    """Fetch `symbol` again, starting at the last index already held for it.

    Takes the keyword arguments returned by `select_fetch_kwargs` for `symbol`, overwrites
    `start` with the result of `select_last_index`, merges the outcome with `**kwargs`
    using `merge_dicts`, and forwards everything to `fetch_symbol`."""
    stored_kwargs = self.select_fetch_kwargs(symbol)
    stored_kwargs["start"] = self.select_last_index(symbol)
    return self.fetch_symbol(symbol, **merge_dicts(stored_kwargs, kwargs))
class RandomOHLCData(SyntheticData):
    """`SyntheticData` subclass that produces random OHLC bars.

    Ticks are generated with `vectorbtpro.data.nb.generate_random_data_1d_nb` and then
    resampled into bars with `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.random_ohlc")

    @classmethod
    def generate_symbol(
        cls,
        symbol: tp.Symbol,
        index: tp.Index,
        n_ticks: tp.Optional[tp.ArrayLike] = None,
        start_value: tp.Optional[float] = None,
        mean: tp.Optional[float] = None,
        std: tp.Optional[float] = None,
        symmetric: tp.Optional[bool] = None,
        seed: tp.Optional[int] = None,
        jitted: tp.JittedOption = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.SymbolData:
        """Generate random OHLC data for a single symbol.

        Args:
            symbol (hashable): Symbol.
            index (pd.Index): Pandas index; one bar is produced per index element.
            n_ticks (int or array_like): Number of ticks per bar.

                Flexible argument: gets broadcast to the length of `index`.
                Can be a template, which is substituted with a context containing
                `symbol` and `index`.
            start_value (float): Value at time 0.

                Does not appear as the first value in the output data.
            mean (float): Drift, or mean of the percentage change.
            std (float): Standard deviation of the percentage change.
            symmetric (bool): Whether to diminish negative returns and make them
                symmetric to positive ones.
            seed (int): Seed to make output deterministic.
            jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
            template_context (dict): Template context, merged over `symbol` and `index`.

        Arguments left as None are resolved with `resolve_custom_setting`.
        For defaults, see `custom.random_ohlc` in `vectorbtpro._settings.data`.

        !!! note
            When setting a seed, remember to pass a seed per symbol using
            `vectorbtpro.data.base.symbol_dict`.
        """
        ticks_per_bar = cls.resolve_custom_setting(n_ticks, "n_ticks")
        context = merge_dicts(dict(symbol=symbol, index=index), template_context)
        ticks_per_bar = substitute_templates(ticks_per_bar, context, eval_id="n_ticks")
        ticks_per_bar = broadcast_array_to(ticks_per_bar, len(index))

        tick_params = dict(
            start_value=cls.resolve_custom_setting(start_value, "start_value"),
            mean=cls.resolve_custom_setting(mean, "mean"),
            std=cls.resolve_custom_setting(std, "std"),
            symmetric=cls.resolve_custom_setting(symmetric, "symmetric"),
        )
        resolved_seed = cls.resolve_custom_setting(seed, "seed")
        if resolved_seed is not None:
            set_seed(resolved_seed)

        # One flat tick series covering all bars, then split into bars of `ticks_per_bar` ticks
        generate_ticks = jit_reg.resolve_option(nb.generate_random_data_1d_nb, jitted)
        ticks = generate_ticks(np.sum(ticks_per_bar), **tick_params)
        to_ohlc = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
        bars = to_ohlc(ticks, ticks_per_bar)
        return pd.DataFrame(bars, index=index, columns=["Open", "High", "Low", "Close"])

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Fetch `symbol` again, starting at the last index already held for it.

        Reuses the keyword arguments returned by `select_fetch_kwargs` with `start` set to
        the result of `select_last_index`, any stored `start_value` dropped, and `seed`
        reset to None. The last stored "Open" value is passed as the new `start_value`."""
        stored_kwargs = self.select_fetch_kwargs(symbol)
        stored_kwargs["start"] = self.select_last_index(symbol)
        stored_kwargs.pop("start_value", None)
        last_open = self.data[symbol]["Open"].iloc[-1]
        stored_kwargs["seed"] = None
        return self.fetch_symbol(symbol, start_value=last_open, **merge_dicts(stored_kwargs, kwargs))
class RandomData(SyntheticData):
    """`SyntheticData` subclass backed by `vectorbtpro.data.nb.generate_random_data_nb`."""

    _settings_path: tp.SettingsPath = dict(custom="data.custom.random")

    @classmethod
    def generate_key(
        cls,
        key: tp.Key,
        index: tp.Index,
        columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
        start_value: tp.Optional[float] = None,
        mean: tp.Optional[float] = None,
        std: tp.Optional[float] = None,
        symmetric: tp.Optional[bool] = None,
        seed: tp.Optional[int] = None,
        jitted: tp.JittedOption = None,
        **kwargs,
    ) -> tp.KeyData:
        """Generate random data for a single feature or symbol.

        Args:
            key (hashable): Feature or symbol.
            index (pd.Index): Pandas index.
            columns (hashable or index_like): Column names.

                A single hashable value yields a Series named after it;
                anything else yields a DataFrame.
            start_value (float): Value at time 0.

                Does not appear as the first value in the output data.
            mean (float): Drift, or mean of the percentage change.
            std (float): Standard deviation of the percentage change.
            symmetric (bool): Whether to diminish negative returns and make them
                symmetric to positive ones.
            seed (int): Seed to make output deterministic.
            jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.

        Arguments left as None are resolved with `resolve_custom_setting`.
        For defaults, see `custom.random` in `vectorbtpro._settings.data`.

        !!! note
            When setting a seed, remember to pass a seed per feature/symbol using
            `vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict`
            or generally `vectorbtpro.data.base.key_dict`.
        """
        # A single hashable column means the caller wants a Series back
        make_series = bool(checks.is_hashable(columns))
        if make_series:
            columns = [columns]
        if not isinstance(columns, pd.Index):
            columns = pd.Index(columns)

        resolved_start = cls.resolve_custom_setting(start_value, "start_value")
        resolved_mean = cls.resolve_custom_setting(mean, "mean")
        resolved_std = cls.resolve_custom_setting(std, "std")
        resolved_symmetric = cls.resolve_custom_setting(symmetric, "symmetric")
        resolved_seed = cls.resolve_custom_setting(seed, "seed")
        if resolved_seed is not None:
            set_seed(resolved_seed)

        generate = jit_reg.resolve_option(nb.generate_random_data_nb, jitted)
        target_shape = (len(index), len(columns))
        values = generate(
            target_shape,
            start_value=to_1d_array(resolved_start),
            mean=to_1d_array(resolved_mean),
            std=to_1d_array(resolved_std),
            symmetric=to_1d_array(resolved_symmetric),
        )

        if not make_series:
            return pd.DataFrame(values, index=index, columns=columns)
        return pd.Series(values[:, 0], index=index, name=columns[0])

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Fetch `key` again, starting at the last index already held for it.

        Reuses the keyword arguments returned by `select_fetch_kwargs` with `start` set to
        the result of `select_last_index`, any stored `start_value` dropped, and `seed`
        reset to None. The second-to-last stored row is passed as the new `start_value`.
        Dispatches to `fetch_feature` if `key_is_feature` is True, otherwise to `fetch_symbol`."""
        stored_kwargs = self.select_fetch_kwargs(key)
        stored_kwargs["start"] = self.select_last_index(key)
        stored_kwargs.pop("start_value", None)
        previous_value = self.data[key].iloc[-2]
        stored_kwargs["seed"] = None
        merged_kwargs = merge_dicts(stored_kwargs, kwargs)
        fetch = self.fetch_feature if key_is_feature else self.fetch_symbol
        return fetch(key, start_value=previous_value, **merged_kwargs)
class RemoteData(CustomData):
    """`CustomData` subclass that serves as the data class for fetching remote data."""

    # Settings path: the "custom" entry points at `data.custom.remote`
    _settings_path: tp.SettingsPath = {"custom": "data.custom.remote"}
@classmethod
def get_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> dict:
    """Call `SQLData.get_custom_settings` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)


@classmethod
def has_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
    """Call `SQLData.has_custom_settings` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)


@classmethod
def get_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
    """Call `SQLData.get_custom_setting` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)


@classmethod
def has_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
    """Call `SQLData.has_custom_setting` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)


@classmethod
def resolve_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
    """Call `SQLData.resolve_custom_setting` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)


@classmethod
def set_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> None:
    """Call `SQLData.set_custom_settings` scoped to an engine.

    `sub_path` is `"engines." + engine_name`, or None if `engine_name` is None."""
    sub_path = None if engine_name is None else "engines." + engine_name
    cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def list_schemas(
    cls,
    pattern: tp.Optional[str] = None,
    use_regex: bool = False,
    sort: bool = True,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
    dispose_engine: tp.Optional[bool] = None,
    **kwargs,
) -> tp.List[str]:
    """List all schemas.

    Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema
    against `pattern`. The schema "information_schema" is always excluded.

    Keyword arguments `**kwargs` are passed to `inspector.get_schema_names`.

    Args:
        pattern (str): Pattern that a schema name must match to be returned.
        use_regex (bool): Passed to `key_match`.
        sort (bool): Whether to return the unique schema names sorted.

            Otherwise, they are returned in the order reported by the inspector.
        engine (str or object): See `SQLData.resolve_engine`.
        engine_name (str): See `SQLData.resolve_engine`.
        engine_config (dict): See `SQLData.resolve_engine`.
        dispose_engine (bool): Whether to dispose the engine when done.

            If None, follows `should_dispose` returned by `SQLData.resolve_engine`.
            Disposal also happens if listing the schemas fails.

    Returns:
        List of unique schema names."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("sqlalchemy")
    from sqlalchemy import inspect

    if engine_config is None:
        engine_config = {}
    engine_meta = cls.resolve_engine(
        engine=engine,
        engine_name=engine_name,
        return_meta=True,
        **engine_config,
    )
    engine = engine_meta["engine"]
    if dispose_engine is None:
        dispose_engine = engine_meta["should_dispose"]

    try:
        all_schemas = inspect(engine).get_schema_names(**kwargs)
        schemas = [
            schema
            for schema in all_schemas
            if (pattern is None or cls.key_match(schema, pattern, use_regex=use_regex))
            and schema != "information_schema"
        ]
    finally:
        # Release the engine even if inspection or matching raised
        if dispose_engine:
            engine.dispose()

    # dict.fromkeys drops duplicates while keeping the first-seen order
    unique_schemas = dict.fromkeys(schemas)
    if sort:
        return sorted(unique_schemas)
    return list(unique_schemas)
@classmethod
def has_schema(
    cls,
    schema: str,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
) -> bool:
    """Return whether `schema` exists in the database behind the resolved engine.

    The engine is resolved with `SQLData.resolve_engine`, with `engine_config` passed as
    keyword arguments. The check itself is delegated to the SQLAlchemy inspector."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("sqlalchemy")
    from sqlalchemy import inspect

    resolved_engine = cls.resolve_engine(
        engine=engine,
        engine_name=engine_name,
        **({} if engine_config is None else engine_config),
    )
    inspector = inspect(resolved_engine)
    return inspector.has_schema(schema)


@classmethod
def create_schema(
    cls,
    schema: str,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
) -> None:
    """Create `schema` unless `SQLData.has_schema` reports that it already exists.

    The engine is resolved with `SQLData.resolve_engine`, with `engine_config` passed as
    keyword arguments. Creation executes `CreateSchema` on a new connection and commits."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("sqlalchemy")
    from sqlalchemy.schema import CreateSchema

    resolved_engine = cls.resolve_engine(
        engine=engine,
        engine_name=engine_name,
        **({} if engine_config is None else engine_config),
    )
    if cls.has_schema(schema, engine=resolved_engine, engine_name=engine_name):
        return
    with resolved_engine.connect() as connection:
        connection.execute(CreateSchema(schema))
        connection.commit()


@classmethod
def has_table(
    cls,
    table: str,
    schema: tp.Optional[str] = None,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
) -> bool:
    """Return whether `table` exists in the database behind the resolved engine.

    `schema` is handed to the SQLAlchemy inspector unchanged. The engine is resolved with
    `SQLData.resolve_engine`, with `engine_config` passed as keyword arguments."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("sqlalchemy")
    from sqlalchemy import inspect

    resolved_engine = cls.resolve_engine(
        engine=engine,
        engine_name=engine_name,
        **({} if engine_config is None else engine_config),
    )
    inspector = inspect(resolved_engine)
    return inspector.has_table(table, schema=schema)
@classmethod
def get_last_row_number(
    cls,
    table: str,
    schema: tp.Optional[str] = None,
    row_number_column: tp.Optional[str] = None,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
) -> tp.Optional[int]:
    """Get the last (that is, highest) row number stored in a table.

    Selects only the row number column, orders it in descending order, and reads
    the first row of the result.

    Args:
        table (str): Table name.
        schema (str): Schema, forwarded to `SQLData.get_table_relation`.
        row_number_column (str): Name of the column with row numbers.

            If None, gets resolved with `SQLData.resolve_engine_setting`.
        engine (str or object): See `SQLData.resolve_engine`.
        engine_name (str): See `SQLData.resolve_engine`.
        engine_config (dict): See `SQLData.resolve_engine`.

    Returns:
        Value of the row number column in the last row, or None if the table has no rows.

    Raises:
        ValueError: If the table doesn't have the row number column."""
    if engine_config is None:
        engine_config = {}
    engine_meta = cls.resolve_engine(
        engine=engine,
        engine_name=engine_name,
        return_meta=True,
        **engine_config,
    )
    engine = engine_meta["engine"]
    engine_name = engine_meta["engine_name"]
    row_number_column = cls.resolve_engine_setting(
        row_number_column,
        "row_number_column",
        engine_name=engine_name,
    )
    table_relation = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name)
    table_column_names = [column.name for column in table_relation.columns]
    if row_number_column not in table_column_names:
        raise ValueError(f"Row number column '{row_number_column}' not found")

    row_number = table_relation.columns.get(row_number_column)
    query = table_relation.select().with_only_columns(row_number).order_by(row_number.desc()).limit(1)
    with engine.connect() as connection:
        first_row = connection.execute(query).first()
        connection.commit()
    # `first()` yields None for an empty result, which must not be indexed
    if first_row is None:
        return None
    return first_row[0]
@classmethod
def pull(
    cls: tp.Type[SQLDataT],
    keys: tp.Union[tp.MaybeKeys] = None,
    *,
    keys_are_features: tp.Optional[bool] = None,
    features: tp.Union[tp.MaybeFeatures] = None,
    symbols: tp.Union[tp.MaybeSymbols] = None,
    schema: tp.Optional[str] = None,
    list_tables_kwargs: tp.KwargsLike = None,
    engine: tp.Union[None, str, EngineT] = None,
    engine_name: tp.Optional[str] = None,
    engine_config: tp.KwargsLike = None,
    dispose_engine: tp.Optional[bool] = None,
    share_engine: tp.Optional[bool] = None,
    **kwargs,
) -> SQLDataT:
    """Override `vectorbtpro.data.base.Data.pull` to resolve and share the engine among the keys
    and use the table names available in the database in case no keys were provided.

    Args:
        keys: Keys to pull. Resolved with `SQLData.resolve_keys_meta`.
        keys_are_features (bool): Forwarded to `SQLData.resolve_keys_meta`.
        features: Forwarded to `SQLData.resolve_keys_meta`.
        symbols: Forwarded to `SQLData.resolve_keys_meta`.
        schema (str): Schema, forwarded to key resolution and to the parent `pull`.
        list_tables_kwargs (dict): Forwarded to `SQLData.resolve_keys_meta`.
        engine (str or object): See `SQLData.resolve_engine`.
        engine_name (str): See `SQLData.resolve_engine`.
        engine_config (dict): See `SQLData.resolve_engine`.
        dispose_engine (bool): Whether to dispose the engine when done.

            If None and the engine is shared, follows `should_dispose` returned
            by `SQLData.resolve_engine`.
        share_engine (bool): Whether to resolve the engine once and reuse it for all keys.

            If None, becomes True only if none of `engine`, `engine_name`, and
            `engine_config` is defined per key.
        **kwargs: Forwarded to the parent `pull`.

    A shared engine is disposed (if requested) even if resolving the keys or pulling fails."""
    if share_engine is None:
        # An engine can only be shared if none of its ingredients is defined per key
        share_engine = not (
            cls.has_key_dict(engine) or cls.has_key_dict(engine_name) or cls.has_key_dict(engine_config)
        )
    if share_engine:
        if engine_config is None:
            engine_config = {}
        engine_meta = cls.resolve_engine(
            engine=engine,
            engine_name=engine_name,
            return_meta=True,
            **engine_config,
        )
        engine = engine_meta["engine"]
        engine_name = engine_meta["engine_name"]
        if dispose_engine is None:
            dispose_engine = engine_meta["should_dispose"]
        # The config has been consumed by the engine created above. Forwarding it together with
        # the engine object makes `resolve_engine` raise for any non-empty config downstream
        engine_config = None
    else:
        # NOTE(review): this discards a user-provided `engine_name` when the engine isn't shared;
        # kept as in the original - confirm that this is intended
        engine_name = None

    try:
        keys_meta = cls.resolve_keys_meta(
            keys=keys,
            keys_are_features=keys_are_features,
            features=features,
            symbols=symbols,
            schema=schema,
            list_tables_kwargs=list_tables_kwargs,
            engine=engine,
            engine_name=engine_name,
            engine_config=engine_config,
        )
        outputs = super(DBData, cls).pull(
            keys_meta["keys"],
            keys_are_features=keys_meta["keys_are_features"],
            schema=schema,
            engine=engine,
            engine_name=engine_name,
            engine_config=engine_config,
            # A shared engine must survive all keys; it's disposed once below
            dispose_engine=False if share_engine else dispose_engine,
            **kwargs,
        )
    finally:
        if share_engine and dispose_engine:
            engine.dispose()
    return outputs
tp.KeyData: """Fetch a feature or symbol from a SQL database. Can use a table name (which defaults to the key) or a custom query. Args: key (str): Feature or symbol. If `table` and `query` are both None, becomes the table name. Key can be in the `SCHEMA:TABLE` format, in this case `schema` argument will be ignored. table (str or Table): Table name or actual object. Cannot be used together with `query`. schema (str): Schema. Cannot be used together with `query`. query (str or Selectable): Custom query. Cannot be used together with `table` and `schema`. engine (str or object): See `SQLData.resolve_engine`. engine_name (str): See `SQLData.resolve_engine`. engine_config (dict): See `SQLData.resolve_engine`. dispose_engine (bool): See `SQLData.resolve_engine`. start (any): Start datetime (if datetime index) or any other start value. Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True and the index is a datetime index. Otherwise, you must ensure the correct type is provided. If the index is a multi-index, start value must be a tuple. Cannot be used together with `query`. Include the condition into the query. end (any): End datetime (if datetime index) or any other end value. Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True and the index is a datetime index. Otherwise, you must ensure the correct type is provided. If the index is a multi-index, end value must be a tuple. Cannot be used together with `query`. Include the condition into the query. align_dates (bool): Whether to align `start` and `end` to the timezone of the index. Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index. parse_dates (bool, list, or dict): Whether to parse dates and how to do it. If `query` is not used, will get mapped into column names. Otherwise, usage of integers is not allowed and column names directly must be used. 
If enabled, will also try to parse the datetime columns that couldn't be parsed by Pandas after the object has been fetched. For dict format, see `pd.read_sql_query`. to_utc (bool, str, or sequence of str): See `SQLData.prepare_dt`. tz (any): Timezone. See `vectorbtpro.utils.datetime_.to_timezone`. start_row (int): Start row. Table must contain the column defined in `row_number_column`. Cannot be used together with `query`. Include the condition into the query. end_row (int): End row. Table must contain the column defined in `row_number_column`. Cannot be used together with `query`. Include the condition into the query. keep_row_number (bool): Whether to return the column defined in `row_number_column`. row_number_column (str): Name of the column with row numbers. index_col (int, str, or list): One or more columns that should become the index. If `query` is not used, will get mapped into column names. Otherwise, usage of integers is not allowed and column names directly must be used. columns (int, str, or list): One or more columns to select. Will get mapped into column names. Cannot be used together with `query`. dtype (dtype_like or dict): Data type of each column. If `query` is not used, will get mapped into column names. Otherwise, usage of integers is not allowed and column names directly must be used. For dict format, see `pd.read_sql_query`. chunksize (int): See `pd.read_sql_query`. chunk_func (callable): Function to select and concatenate chunks from `Iterator`. Gets called only if `chunksize` is set. squeeze (int): Whether to squeeze a DataFrame with one column into a Series. **read_sql_kwargs: Other keyword arguments passed to `pd.read_sql_query`. See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for other arguments. For defaults, see `custom.sql` in `vectorbtpro._settings.data`. Global settings can be provided per engine name using the `engines` dictionary. 
""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("sqlalchemy") from sqlalchemy import Selectable, Select, FromClause, and_, text if engine_config is None: engine_config = {} engine_meta = cls.resolve_engine( engine=engine, engine_name=engine_name, return_meta=True, **engine_config, ) engine = engine_meta["engine"] engine_name = engine_meta["engine_name"] should_dispose = engine_meta["should_dispose"] if dispose_engine is None: dispose_engine = should_dispose if table is not None and query is not None: raise ValueError("Must provide either table name or query, not both") if schema is not None and query is not None: raise ValueError("Schema cannot be applied to custom queries") if table is None and query is None: if ":" in key: schema, table = key.split(":") else: table = key start = cls.resolve_engine_setting(start, "start", engine_name=engine_name) end = cls.resolve_engine_setting(end, "end", engine_name=engine_name) align_dates = cls.resolve_engine_setting(align_dates, "align_dates", engine_name=engine_name) parse_dates = cls.resolve_engine_setting(parse_dates, "parse_dates", engine_name=engine_name) to_utc = cls.resolve_engine_setting(to_utc, "to_utc", engine_name=engine_name) tz = cls.resolve_engine_setting(tz, "tz", engine_name=engine_name) start_row = cls.resolve_engine_setting(start_row, "start_row", engine_name=engine_name) end_row = cls.resolve_engine_setting(end_row, "end_row", engine_name=engine_name) keep_row_number = cls.resolve_engine_setting(keep_row_number, "keep_row_number", engine_name=engine_name) row_number_column = cls.resolve_engine_setting(row_number_column, "row_number_column", engine_name=engine_name) index_col = cls.resolve_engine_setting(index_col, "index_col", engine_name=engine_name) columns = cls.resolve_engine_setting(columns, "columns", engine_name=engine_name) dtype = cls.resolve_engine_setting(dtype, "dtype", engine_name=engine_name) chunksize = cls.resolve_engine_setting(chunksize, "chunksize", 
engine_name=engine_name) chunk_func = cls.resolve_engine_setting(chunk_func, "chunk_func", engine_name=engine_name) squeeze = cls.resolve_engine_setting(squeeze, "squeeze", engine_name=engine_name) read_sql_kwargs = cls.resolve_engine_setting( read_sql_kwargs, "read_sql_kwargs", merge=True, engine_name=engine_name ) if query is None or isinstance(query, (Selectable, FromClause)): if query is None: if isinstance(table, str): table = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name) else: table = query table_column_names = [] for column in table.columns: table_column_names.append(column.name) def _resolve_columns(c): if checks.is_int(c): c = table_column_names[int(c)] elif not isinstance(c, str): new_c = [] for _c in c: if checks.is_int(_c): new_c.append(table_column_names[int(_c)]) else: if _c not in table_column_names: for __c in table_column_names: if _c.lower() == __c.lower(): _c = __c break new_c.append(_c) c = new_c else: if c not in table_column_names: for _c in table_column_names: if c.lower() == _c.lower(): return _c return c if index_col is False: index_col = None if index_col is not None: index_col = _resolve_columns(index_col) if isinstance(index_col, str): index_col = [index_col] if columns is not None: columns = _resolve_columns(columns) if isinstance(columns, str): columns = [columns] if parse_dates is not None: if not isinstance(parse_dates, bool): if isinstance(parse_dates, dict): parse_dates = dict(zip(_resolve_columns(parse_dates.keys()), parse_dates.values())) else: parse_dates = _resolve_columns(parse_dates) if isinstance(parse_dates, str): parse_dates = [parse_dates] if dtype is not None: if isinstance(dtype, dict): dtype = dict(zip(_resolve_columns(dtype.keys()), dtype.values())) if not isinstance(table, Select): query = table.select() else: query = table if index_col is not None: for col in index_col: query = query.order_by(col) if index_col is not None and columns is not None: pre_columns = [] for col in 
index_col: if col not in columns: pre_columns.append(col) columns = pre_columns + columns if keep_row_number and columns is not None: if row_number_column in table_column_names and row_number_column not in columns: columns = [row_number_column] + columns elif not keep_row_number and columns is None: if row_number_column in table_column_names: columns = [col for col in table_column_names if col != row_number_column] if columns is not None: query = query.with_only_columns(*[table.columns.get(c) for c in columns]) def _to_native_type(x): if checks.is_np_scalar(x): return x.item() return x if start_row is not None or end_row is not None: if start is not None or end is not None: raise ValueError("Can either filter by row numbers or by index, not both") _row_number_column = table.columns.get(row_number_column) if _row_number_column is None: raise ValueError(f"Row number column '{row_number_column}' not found") and_list = [] if start_row is not None: and_list.append(_row_number_column >= _to_native_type(start_row)) if end_row is not None: and_list.append(_row_number_column < _to_native_type(end_row)) query = query.where(and_(*and_list)) if start is not None or end is not None: if index_col is None: raise ValueError("Must provide index column for filtering by start and end") if align_dates: first_obj = pd.read_sql_query( query.limit(1), engine.connect(), index_col=index_col, parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted dtype=dtype, chunksize=None, **read_sql_kwargs, ) first_obj = cls.prepare_dt( first_obj, parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates, to_utc=False, ) if isinstance(first_obj.index, pd.DatetimeIndex): if tz is None: tz = first_obj.index.tz if first_obj.index.tz is not None: if start is not None: start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=first_obj.index.tz) if end is not None: end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=first_obj.index.tz) else: if start is not 
None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc) ): start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc") start = dt.to_naive_datetime(start) else: start = dt.to_naive_datetime(start, tz=tz) if end is not None: if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc) ): end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc") end = dt.to_naive_datetime(end) else: end = dt.to_naive_datetime(end, tz=tz) and_list = [] if start is not None: if len(index_col) > 1: if not isinstance(start, tuple): raise TypeError("Start must be a tuple if the index is a multi-index") if len(start) != len(index_col): raise ValueError("Start tuple must match the number of levels in the multi-index") for i in range(len(index_col)): index_column = table.columns.get(index_col[i]) and_list.append(index_column >= _to_native_type(start[i])) else: index_column = table.columns.get(index_col[0]) and_list.append(index_column >= _to_native_type(start)) if end is not None: if len(index_col) > 1: if not isinstance(end, tuple): raise TypeError("End must be a tuple if the index is a multi-index") if len(end) != len(index_col): raise ValueError("End tuple must match the number of levels in the multi-index") for i in range(len(index_col)): index_column = table.columns.get(index_col[i]) and_list.append(index_column < _to_native_type(end[i])) else: index_column = table.columns.get(index_col[0]) and_list.append(index_column < _to_native_type(end)) query = query.where(and_(*and_list)) else: def _check_columns(c, arg_name): if checks.is_int(c): raise ValueError(f"Must provide column as a string for '{arg_name}'") elif not isinstance(c, str): for _c in c: if checks.is_int(_c): raise ValueError(f"Must provide each column as a string for '{arg_name}'") if start is not None: raise ValueError("Start cannot be applied to 
custom queries") if end is not None: raise ValueError("End cannot be applied to custom queries") if start_row is not None: raise ValueError("Start row cannot be applied to custom queries") if end_row is not None: raise ValueError("End row cannot be applied to custom queries") if index_col is False: index_col = None if index_col is not None: _check_columns(index_col, "index_col") if isinstance(index_col, str): index_col = [index_col] if columns is not None: raise ValueError("Columns cannot be applied to custom queries") if parse_dates is not None: if not isinstance(parse_dates, bool): if isinstance(parse_dates, dict): _check_columns(parse_dates.keys(), "parse_dates") else: _check_columns(parse_dates, "parse_dates") if isinstance(parse_dates, str): parse_dates = [parse_dates] if dtype is not None: _check_columns(dtype.keys(), "dtype") if isinstance(query, str): query = text(query) obj = pd.read_sql_query( query, engine.connect(), index_col=index_col, parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted dtype=dtype, chunksize=chunksize, **read_sql_kwargs, ) if isinstance(obj, Iterator): if chunk_func is None: obj = pd.concat(list(obj), axis=0) else: obj = chunk_func(obj) obj = cls.prepare_dt( obj, parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates, to_utc=to_utc, ) if not isinstance(obj.index, pd.MultiIndex): if obj.index.name == "index": obj.index.name = None if isinstance(obj.index, pd.DatetimeIndex) and tz is None: tz = obj.index.tz if isinstance(obj, pd.DataFrame) and squeeze: obj = obj.squeeze("columns") if isinstance(obj, pd.Series) and obj.name == "0": obj.name = None if dispose_engine: engine.dispose() if keep_row_number: return obj, dict(tz=tz, row_number_column=row_number_column) return obj, dict(tz=tz) @classmethod def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData: """Fetch the table of a feature. 
def update_key(
    self,
    key: str,
    from_last_row: tp.Optional[bool] = None,
    from_last_index: tp.Optional[bool] = None,
    **kwargs,
) -> tp.KeyData:
    """Update data of a feature or symbol.

    Re-fetches `key` with the keyword arguments returned by `select_fetch_kwargs`
    (overridden by `kwargs`), optionally resuming from the last stored row number
    or from the last stored index value.

    Args:
        key (str): Feature or symbol.
        from_last_row (bool): Whether to set `start_row` to the last value of the row number column.

            If None, becomes True only if no `query`, `start`, or `end` is in effect,
            `from_last_index` is not True, and the returned `row_number_column` is
            among the columns of the wrapper.
        from_last_index (bool): Whether to set `start` to the last index value.

            If None, becomes True only if no `query`, `start_row`, or `end_row` is
            in effect and `from_last_row` is not True.
        **kwargs: Keyword arguments that take precedence over the selected fetch arguments.

    Raises:
        ValueError: If `from_last_row` is enabled but `row_number_column` is missing
            from the returned keyword arguments.
    """
    stored_kwargs = self.select_fetch_kwargs(key)
    returned_kwargs = self.select_returned_kwargs(key)
    effective_kwargs = merge_dicts(stored_kwargs, kwargs)

    def _given(name: str) -> bool:
        # An argument counts as "given" only when it resolves to something other than None
        return effective_kwargs.get(name, None) is not None

    if from_last_row is None:
        row_resume_blocked = (
            _given("query")
            or from_last_index is True
            or _given("start")
            or _given("end")
        )
        from_last_row = (
            not row_resume_blocked
            and "row_number_column" in returned_kwargs
            and returned_kwargs["row_number_column"] in self.wrapper.columns
        )

    if from_last_index is None:
        from_last_index = not (
            _given("query")
            or from_last_row is True
            or _given("start_row")
            or _given("end_row")
        )

    if from_last_row:
        if "row_number_column" not in returned_kwargs:
            raise ValueError("Argument row_number_column must be in returned_kwargs for from_last_row")
        row_numbers = self.data[key][returned_kwargs["row_number_column"]]
        stored_kwargs["start_row"] = row_numbers.iloc[-1]
    if from_last_index:
        stored_kwargs["start"] = self.select_last_index(key)

    # Explicitly passed arguments win over the resume points computed above
    final_kwargs = merge_dicts(stored_kwargs, kwargs)
    fetcher = self.fetch_feature if self.feature_oriented else self.fetch_symbol
    return fetcher(key, **final_kwargs)
class SyntheticData(CustomData):
    """Data class for generating synthetic data.

    Subclasses are expected to override `SyntheticData.generate_key` (or the feature/symbol
    specific variants); building the index and dispatching is handled here."""

    _settings_path: tp.SettingsPath = {"custom": "data.custom.synthetic"}

    @classmethod
    def generate_key(cls, key: tp.Key, index: tp.Index, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Generate data of a feature or symbol over `index`.

        Abstract: always raises `NotImplementedError` unless overridden."""
        raise NotImplementedError

    @classmethod
    def generate_feature(cls, feature: tp.Feature, index: tp.Index, **kwargs) -> tp.FeatureData:
        """Generate data of a feature.

        Delegates to `SyntheticData.generate_key` with `key_is_feature=True`."""
        return cls.generate_key(feature, index, key_is_feature=True, **kwargs)

    @classmethod
    def generate_symbol(cls, symbol: tp.Symbol, index: tp.Index, **kwargs) -> tp.SymbolData:
        """Generate data for a symbol.

        Delegates to `SyntheticData.generate_key` with `key_is_feature=False`."""
        return cls.generate_key(symbol, index, key_is_feature=False, **kwargs)

    @classmethod
    def fetch_key(
        cls,
        key: tp.Symbol,
        key_is_feature: bool = False,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        periods: tp.Optional[int] = None,
        timeframe: tp.Optional[tp.FrequencyLike] = None,
        tz: tp.TimezoneLike = None,
        normalize: tp.Optional[bool] = None,
        inclusive: tp.Optional[str] = None,
        **kwargs,
    ) -> tp.KeyData:
        """Generate data of a feature or symbol.

        Builds a datetime index with `vectorbtpro.utils.datetime_.date_range` and hands it to
        `SyntheticData.generate_feature` or `SyntheticData.generate_symbol` (depending on
        `key_is_feature`) together with `**kwargs`.

        Returns a tuple of the generated object and a dict with the keys `tz` and `freq`.
        If `tz` resolves to None, the timezone of the generated index is used instead.

        Raises `ValueError` if the generated index is empty.

        For defaults, see `custom.synthetic` in `vectorbtpro._settings.data`."""
        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        # Note: `periods` is forwarded as passed, without consulting the settings
        range_kwargs = dict(
            start=start,
            end=end,
            periods=periods,
            freq=timeframe,
            normalize=cls.resolve_custom_setting(normalize, "normalize"),
            inclusive=cls.resolve_custom_setting(inclusive, "inclusive"),
        )

        index = dt.date_range(**range_kwargs)
        if tz is None:
            tz = index.tz
        if len(index) == 0:
            raise ValueError("Date range is empty")

        generate = cls.generate_feature if key_is_feature else cls.generate_symbol
        return generate(key, index, **kwargs), dict(tz=tz, freq=timeframe)

    @classmethod
    def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Generate data of a feature.

        Delegates to `SyntheticData.fetch_key` with `key_is_feature=True`."""
        return cls.fetch_key(feature, key_is_feature=True, **kwargs)

    @classmethod
    def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Generate data for a symbol.

        Delegates to `SyntheticData.fetch_key` with `key_is_feature=False`."""
        return cls.fetch_key(symbol, key_is_feature=False, **kwargs)

    def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
        """Update data of a feature or symbol.

        Reuses the keyword arguments returned by `select_fetch_kwargs`, sets `start` to the
        last index of `key`, and lets `**kwargs` override any of them."""
        stored_kwargs = self.select_fetch_kwargs(key)
        stored_kwargs["start"] = self.select_last_index(key)
        final_kwargs = merge_dicts(stored_kwargs, kwargs)
        fetcher = self.fetch_feature if key_is_feature else self.fetch_symbol
        return fetcher(key, **final_kwargs)

    def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
        """Update data of a feature.

        Delegates to `SyntheticData.update_key` with `key_is_feature=True`."""
        return self.update_key(feature, key_is_feature=True, **kwargs)

    def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
        """Update data for a symbol.

        Delegates to `SyntheticData.update_key` with `key_is_feature=False`."""
        return self.update_key(symbol, key_is_feature=False, **kwargs)
# =================================================================================== """Module with `TVData`.""" import datetime import json import math import random import re import string import time import pandas as pd import requests from websocket import WebSocket from vectorbtpro import _typing as tp from vectorbtpro.data.custom.remote import RemoteData from vectorbtpro.utils import datetime_ as dt from vectorbtpro.utils.config import merge_dicts, Configured from vectorbtpro.utils.pbar import ProgressBar from vectorbtpro.utils.template import CustomTemplate __all__ = [ "TVClient", "TVData", ] SIGNIN_URL = "https://www.tradingview.com/accounts/signin/" """Sign-in URL.""" SEARCH_URL = ( "https://symbol-search.tradingview.com/symbol_search/v3/?" "text={text}&" "start={start}&" "hl=1&" "exchange={exchange}&" "lang=en&" "search_type=undefined&" "domain=production&" "sort_by_country=US" ) """Symbol search URL.""" SCAN_URL = "https://scanner.tradingview.com/{market}/scan" """Market scanner URL.""" ORIGIN_URL = "https://data.tradingview.com" """Origin URL.""" REFERER_URL = "https://www.tradingview.com" """Referer URL.""" WS_URL = "wss://data.tradingview.com/socket.io/websocket" """Websocket URL.""" PRO_WS_URL = "wss://prodata.tradingview.com/socket.io/websocket" """Websocket URL (Pro).""" WS_TIMEOUT = 5 """Websocket timeout.""" MARKET_LIST = [ "america", "argentina", "australia", "austria", "bahrain", "bangladesh", "belgium", "brazil", "canada", "chile", "china", "colombia", "cyprus", "czech", "denmark", "egypt", "estonia", "euronext", "finland", "france", "germany", "greece", "hongkong", "hungary", "iceland", "india", "indonesia", "israel", "italy", "japan", "kenya", "korea", "ksa", "kuwait", "latvia", "lithuania", "luxembourg", "malaysia", "mexico", "morocco", "netherlands", "newzealand", "nigeria", "norway", "pakistan", "peru", "philippines", "poland", "portugal", "qatar", "romania", "rsa", "russia", "serbia", "singapore", "slovakia", "spain", "srilanka", 
"sweden", "switzerland", "taiwan", "thailand", "tunisia", "turkey", "uae", "uk", "venezuela", "vietnam", ] """List of markets supported by the market scanner (list may be incomplete).""" FIELD_LIST = [ "name", "description", "logoid", "update_mode", "type", "typespecs", "close", "pricescale", "minmov", "fractional", "minmove2", "currency", "change", "change_abs", "Recommend.All", "volume", "Value.Traded", "market_cap_basic", "fundamental_currency_code", "Perf.1Y.MarketCap", "price_earnings_ttm", "earnings_per_share_basic_ttm", "number_of_employees_fy", "sector", "market", ] """List of fields supported by the market scanner (list may be incomplete).""" USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36" """User agent.""" class TVClient(Configured): """Client for TradingView.""" def __init__( self, username: tp.Optional[str] = None, password: tp.Optional[str] = None, auth_token: tp.Optional[str] = None, **kwargs, ) -> None: """Client for TradingView.""" Configured.__init__( self, username=username, password=password, auth_token=auth_token, **kwargs, ) if auth_token is None: auth_token = self.auth(username, password) elif username is not None or password is not None: raise ValueError("Must provide either username and password, or auth_token") self._auth_token = auth_token self._ws = None self._session = self.generate_session() self._chart_session = self.generate_chart_session() @property def auth_token(self) -> str: """Authentication token.""" return self._auth_token @property def ws(self) -> WebSocket: """Instance of `websocket.Websocket`.""" return self._ws @property def session(self) -> str: """Session.""" return self._session @property def chart_session(self) -> str: """Chart session.""" return self._chart_session @classmethod def auth( cls, username: tp.Optional[str] = None, password: tp.Optional[str] = None, ) -> str: """Authenticate.""" if username is not None and password is not 
None: data = {"username": username, "password": password, "remember": "on"} headers = {"Referer": REFERER_URL, "User-Agent": USER_AGENT} response = requests.post(url=SIGNIN_URL, data=data, headers=headers) response.raise_for_status() json = response.json() if "user" not in json or "auth_token" not in json["user"]: raise ValueError(json) return json["user"]["auth_token"] if username is not None or password is not None: raise ValueError("Must provide both username and password") return "unauthorized_user_token" @classmethod def generate_session(cls) -> str: """Generate session.""" stringLength = 12 letters = string.ascii_lowercase random_string = "".join(random.choice(letters) for _ in range(stringLength)) return "qs_" + random_string @classmethod def generate_chart_session(cls) -> str: """Generate chart session.""" stringLength = 12 letters = string.ascii_lowercase random_string = "".join(random.choice(letters) for _ in range(stringLength)) return "cs_" + random_string def create_connection(self, pro_data: bool = True) -> None: """Create a websocket connection.""" from websocket import create_connection if pro_data: self._ws = create_connection( PRO_WS_URL, headers=json.dumps({"Origin": ORIGIN_URL}), timeout=WS_TIMEOUT, ) else: self._ws = create_connection( WS_URL, headers=json.dumps({"Origin": ORIGIN_URL}), timeout=WS_TIMEOUT, ) @classmethod def filter_raw_message(cls, text) -> tp.Tuple[str, str]: """Filter raw message.""" found = re.search('"m":"(.+?)",', text).group(1) found2 = re.search('"p":(.+?"}"])}', text).group(1) return found, found2 @classmethod def prepend_header(cls, st: str) -> str: """Prepend a header.""" return "~m~" + str(len(st)) + "~m~" + st @classmethod def construct_message(cls, func: str, param_list: tp.List[str]) -> str: """Construct a message.""" return json.dumps({"m": func, "p": param_list}, separators=(",", ":")) def create_message(self, func: str, param_list: tp.List[str]) -> str: """Create a message.""" return 
self.prepend_header(self.construct_message(func, param_list)) def send_message(self, func: str, param_list: tp.List[str]) -> None: """Send a message.""" m = self.create_message(func, param_list) self.ws.send(m) @classmethod def convert_raw_data(cls, raw_data: str, symbol: str) -> pd.DataFrame: """Process raw data into a DataFrame.""" search_result = re.search(r'"s":\[(.+?)\}\]', raw_data) if search_result is None: raise ValueError("Couldn't parse data returned by TradingView: {}".format(raw_data)) out = search_result.group(1) x = out.split(',{"') data = list() volume_data = True for xi in x: xi = re.split(r"\[|:|,|\]", xi) ts = datetime.datetime.utcfromtimestamp(float(xi[4])) row = [ts] for i in range(5, 10): # skip converting volume data if does not exists if not volume_data and i == 9: row.append(0.0) continue try: row.append(float(xi[i])) except ValueError: volume_data = False row.append(0.0) data.append(row) data = pd.DataFrame(data, columns=["datetime", "open", "high", "low", "close", "volume"]) data = data.set_index("datetime") data.insert(0, "symbol", value=symbol) return data @classmethod def format_symbol(cls, symbol: str, exchange: str, fut_contract: tp.Optional[int] = None) -> str: """Format a symbol.""" if ":" in symbol: pass elif fut_contract is None: symbol = f"{exchange}:{symbol}" elif isinstance(fut_contract, int): symbol = f"{exchange}:{symbol}{fut_contract}!" 
else: raise ValueError(f"Invalid fut_contract: '{fut_contract}'") return symbol def get_hist( self, symbol: str, exchange: str = "NSE", interval: str = "1D", fut_contract: tp.Optional[int] = None, adjustment: str = "splits", extended_session: bool = False, pro_data: bool = True, limit: int = 20000, return_raw: bool = False, ) -> tp.Union[str, tp.Frame]: """Get historical data.""" symbol = self.format_symbol(symbol=symbol, exchange=exchange, fut_contract=fut_contract) backadjustment = False if symbol.endswith("!A"): backadjustment = True symbol = symbol.replace("!A", "!") self.create_connection(pro_data=pro_data) self.send_message("set_auth_token", [self.auth_token]) self.send_message("chart_create_session", [self.chart_session, ""]) self.send_message("quote_create_session", [self.session]) self.send_message( "quote_set_fields", [ self.session, "ch", "chp", "current_session", "description", "local_description", "language", "exchange", "fractional", "is_tradable", "lp", "lp_time", "minmov", "minmove2", "original_name", "pricescale", "pro_name", "short_name", "type", "update_mode", "volume", "currency_code", "rchp", "rtc", ], ) self.send_message("quote_add_symbols", [self.session, symbol, {"flags": ["force_permission"]}]) self.send_message("quote_fast_symbols", [self.session, symbol]) self.send_message( "resolve_symbol", [ self.chart_session, "symbol_1", '={"symbol":"' + symbol + '","adjustment":"' + adjustment + ("" if not backadjustment else '","backadjustment":"default') + '","session":' + ('"regular"' if not extended_session else '"extended"') + "}", ], ) self.send_message("create_series", [self.chart_session, "s1", "s1", "symbol_1", interval, limit]) self.send_message("switch_timezone", [self.chart_session, "exchange"]) raw_data = "" while True: try: result = self.ws.recv() raw_data += result + "\n" except Exception as e: break if "series_completed" in result: break if return_raw: return raw_data return self.convert_raw_data(raw_data, symbol) @classmethod def 
search_symbol( cls, text: tp.Optional[str] = None, exchange: tp.Optional[str] = None, pages: tp.Optional[int] = None, delay: tp.Optional[int] = None, retries: int = 3, show_progress: bool = True, pbar_kwargs: tp.KwargsLike = None, ) -> tp.List[dict]: """Search for a symbol.""" if text is None: text = "" if exchange is None: exchange = "" if pbar_kwargs is None: pbar_kwargs = {} symbols_list = [] pbar = None pages_fetched = 0 while True: for i in range(retries): try: url = SEARCH_URL.format(text=text, exchange=exchange.upper(), start=len(symbols_list)) headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT} resp = requests.get(url, headers=headers) symbols_data = json.loads(resp.text.replace("", "").replace("", "")) break except json.JSONDecodeError as e: if i == retries - 1: raise e if delay is not None: time.sleep(delay) symbols_remaining = symbols_data.get("symbols_remaining", 0) new_symbols = symbols_data.get("symbols", []) symbols_list.extend(new_symbols) if pages is None and symbols_remaining > 0: show_pbar = True elif pages is not None and pages > 1: show_pbar = True else: show_pbar = False if pbar is None and show_pbar: if pages is not None: total = pages else: total = math.ceil((len(new_symbols) + symbols_remaining) / len(new_symbols)) pbar = ProgressBar( total=total, show_progress=show_progress, **pbar_kwargs, ) pbar.enter() if pbar is not None: max_symbols = len(symbols_list) + symbols_remaining if pages is not None: max_symbols = min(max_symbols, pages * len(new_symbols)) pbar.set_description(dict(symbols="%d/%d" % (len(symbols_list), max_symbols))) pbar.update() if symbols_remaining == 0: break pages_fetched += 1 if pages is not None and pages_fetched >= pages: break if delay is not None: time.sleep(delay) if pbar is not None: pbar.exit() return symbols_list @classmethod def scan_symbols(cls, market: tp.Optional[str] = None, **kwargs) -> tp.List[dict]: """Scan symbols in a region/market.""" if market is None: market = "global" 
url = SCAN_URL.format(market=market.lower()) headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT} resp = requests.post(url, json.dumps(kwargs), headers=headers) symbols_list = json.loads(resp.text)["data"] return symbols_list TVDataT = tp.TypeVar("TVDataT", bound="TVData") class TVData(RemoteData): """Data class for fetching from TradingView. See `TVData.fetch_symbol` for arguments. !!! note If you're getting the error "Please confirm that you are not a robot by clicking the captcha box." when attempting to authenticate, use `auth_token` instead of `username` and `password`. To get the authentication token, go to TradingView, log in, visit any chart, open your console's developer tools, and search for "auth_token". Usage: * Set up the credentials globally (optional): ```pycon >>> from vectorbtpro import * >>> vbt.TVData.set_custom_settings( ... client_config=dict( ... username="YOUR_USERNAME", ... password="YOUR_PASSWORD", ... auth_token="YOUR_AUTH_TOKEN", # optional, instead of username and password ... ) ... ) ``` * Pull data: ```pycon >>> data = vbt.TVData.pull( ... "NASDAQ:AAPL", ... timeframe="1 hour" ... 
# NOTE: the definitions below are methods of `TVData`; the class header lies
# before the first line of this chunk.

@classmethod
def list_symbols(
    cls,
    *,
    exchange_pattern: tp.Optional[str] = None,
    symbol_pattern: tp.Optional[str] = None,
    use_regex: bool = False,
    sort: bool = True,
    client: tp.Optional[TVClient] = None,
    client_config: tp.DictLike = None,
    text: tp.Optional[str] = None,
    exchange: tp.Optional[str] = None,
    pages: tp.Optional[int] = None,
    delay: tp.Optional[int] = None,
    retries: tp.Optional[int] = None,
    show_progress: tp.Optional[bool] = None,
    pbar_kwargs: tp.KwargsLike = None,
    market: tp.Optional[str] = None,
    markets: tp.Optional[tp.List[str]] = None,
    fields: tp.Optional[tp.MaybeIterable[str]] = None,
    filter_by: tp.Union[None, tp.Callable, CustomTemplate] = None,
    groups: tp.Optional[tp.MaybeIterable[tp.Dict[str, tp.MaybeIterable[str]]]] = None,
    template_context: tp.KwargsLike = None,
    return_field_data: bool = False,
    **scanner_kwargs,
) -> tp.Union[tp.List[str], tp.List[tp.Kwargs]]:
    """List all symbols.

    Uses symbol search when either `text` or `exchange` is provided (returns a subset of symbols).
    Otherwise, uses the market scanner (returns all symbols, big payload). If none of `market`,
    `text`, and `exchange` is provided, the market "global" is scanned.

    When using the market scanner, use `market` to filter by one or multiple markets.
    For the list of available markets, see `MARKET_LIST`.

    Use `fields` to make the market scanner return additional information that can be used for
    filtering with `filter_by`. Such information is passed to the function as a dictionary where
    fields are keys. The function can also be a template that can use the same information provided
    as a context, or a list of values that should be matched against the values corresponding to
    their fields. For the list of available fields, see `FIELD_LIST`. Argument `fields` can also
    be "all". Set `return_field_data` to True to return a list with (filtered) field data;
    this flag is ignored when symbol search is used.

    Use `groups` to provide a single dictionary or a list of dictionaries with groups.
    Each dictionary can be provided either in a compressed format, such as `dict(index=index)`,
    or in a full format, such as `dict(type="index", values=[index])`.

    Keyword arguments `scanner_kwargs` are encoded and passed directly to the market scanner.

    Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each exchange against
    `exchange_pattern` and each symbol against `symbol_pattern`.

    Raises:
        ValueError: If `market` is combined with `text`/`exchange`, if both `fields` and
            `scanner_kwargs["columns"]` are given, if `filter_by` is used without `fields`,
            if a list-like `filter_by` doesn't have one value per field, or if the scanner
            returns no data.

    Usage:
        * List all symbols (market scanner):

        ```pycon
        >>> from vectorbtpro import *

        >>> vbt.TVData.list_symbols()
        ```

        * Search for symbols matching a pattern (market scanner, client-side):

        ```pycon
        >>> vbt.TVData.list_symbols(symbol_pattern="BTC*")
        ```

        * Search for exchanges matching a pattern (market scanner, client-side):

        ```pycon
        >>> vbt.TVData.list_symbols(exchange_pattern="NASDAQ")
        ```

        * Search for symbols containing a text (symbol search, server-side):

        ```pycon
        >>> vbt.TVData.list_symbols(text="BTC")
        ```

        * List symbols from an exchange (symbol search):

        ```pycon
        >>> vbt.TVData.list_symbols(exchange="NASDAQ")
        ```

        * List symbols from a market (market scanner):

        ```pycon
        >>> vbt.TVData.list_symbols(market="poland")
        ```

        * List index constituents (market scanner):

        ```pycon
        >>> vbt.TVData.list_symbols(groups=dict(index="NASDAQ:NDX"))
        ```

        * Filter symbols by fields using a function (market scanner):

        ```pycon
        >>> vbt.TVData.list_symbols(
        ...     market="america",
        ...     fields=["sector"],
        ...     filter_by=lambda context: context["sector"] == "Technology Services"
        ... )
        ```

        * Filter symbols by fields using a template (market scanner):

        ```pycon
        >>> vbt.TVData.list_symbols(
        ...     market="america",
        ...     fields=["sector"],
        ...     filter_by=vbt.RepEval("sector == 'Technology Services'")
        ... )
        ```
    """
    # Arguments left as None fall back to the "search" and "scanner" sub-settings
    pages = cls.resolve_custom_setting(pages, "pages", sub_path="search", sub_path_only=True)
    delay = cls.resolve_custom_setting(delay, "delay", sub_path="search", sub_path_only=True)
    retries = cls.resolve_custom_setting(retries, "retries", sub_path="search", sub_path_only=True)
    show_progress = cls.resolve_custom_setting(
        show_progress, "show_progress", sub_path="search", sub_path_only=True
    )
    pbar_kwargs = cls.resolve_custom_setting(
        pbar_kwargs, "pbar_kwargs", sub_path="search", sub_path_only=True, merge=True
    )
    markets = cls.resolve_custom_setting(markets, "markets", sub_path="scanner", sub_path_only=True)
    fields = cls.resolve_custom_setting(fields, "fields", sub_path="scanner", sub_path_only=True)
    filter_by = cls.resolve_custom_setting(filter_by, "filter_by", sub_path="scanner", sub_path_only=True)
    groups = cls.resolve_custom_setting(groups, "groups", sub_path="scanner", sub_path_only=True)
    template_context = cls.resolve_custom_setting(
        template_context, "template_context", sub_path="scanner", sub_path_only=True, merge=True
    )
    scanner_kwargs = cls.resolve_custom_setting(
        scanner_kwargs, "scanner_kwargs", sub_path="scanner", sub_path_only=True, merge=True
    )

    if market is None and text is None and exchange is None:
        market = "global"
    if market is not None and (text is not None or exchange is not None):
        raise ValueError("Please provide either market, or text and/or exchange")
    if client_config is None:
        client_config = {}
    client = cls.resolve_client(client=client, **client_config)

    if market is None:
        # Symbol search (server-side)
        data = client.search_symbol(
            text=text,
            exchange=exchange,
            pages=pages,
            delay=delay,
            retries=retries,
            show_progress=show_progress,
            pbar_kwargs=pbar_kwargs,
        )
        all_symbols = map(lambda x: x["exchange"] + ":" + x["symbol"], data)
        # Symbol search yields no field data, so only plain symbols can be returned
        return_field_data = False
    else:
        # Market scanner
        if markets is not None:
            scanner_kwargs["markets"] = markets
        if fields is not None:
            if "columns" in scanner_kwargs:
                raise ValueError("Use fields instead of columns")
            if isinstance(fields, str):
                if fields.lower() == "all":
                    fields = FIELD_LIST
                else:
                    fields = [fields]
            else:
                # Fields are zipped against every scanned item below, so a one-shot
                # iterable must be materialized to avoid being exhausted by the first item
                fields = list(fields)
            scanner_kwargs["columns"] = fields
        if groups is not None:
            if isinstance(groups, dict):
                groups = [groups]
            # Expand the compressed format `dict(type=values)` into the full format
            new_groups = []
            for group in groups:
                if "type" in group:
                    new_groups.append(group)
                else:
                    for k, v in group.items():
                        if isinstance(v, str):
                            v = [v]
                        new_groups.append(dict(type=k, values=v))
            groups = new_groups
            # Copy the user-provided dict so that it's not modified in place
            if "symbols" in scanner_kwargs:
                scanner_kwargs["symbols"] = dict(scanner_kwargs["symbols"])
            else:
                scanner_kwargs["symbols"] = dict()
            scanner_kwargs["symbols"]["groups"] = groups
        if isinstance(filter_by, str):
            filter_by = [filter_by]

        data = client.scan_symbols(market.lower(), **scanner_kwargs)
        if data is None:
            raise ValueError("No data returned by TradingView")

        all_symbols = []
        for item in data:
            if fields is not None:
                item = {"symbol": item["s"], **dict(zip(fields, item["d"]))}
            else:
                item = {"symbol": item["s"]}
            if filter_by is not None:
                if fields is None:
                    raise ValueError("Must provide fields for filter_by")
                context = merge_dicts(item, template_context)
                if isinstance(filter_by, CustomTemplate):
                    if not filter_by.substitute(context, eval_id="filter_by"):
                        continue
                elif callable(filter_by):
                    if not filter_by(context):
                        continue
                else:
                    # Sequence of values: each must equal the value of the field at the same position
                    if len(fields) != len(filter_by):
                        raise ValueError("Fields and filter_by must have the same number of values")
                    if any(context[field] != value for field, value in zip(fields, filter_by)):
                        continue
            if return_field_data:
                all_symbols.append(item)
            else:
                all_symbols.append(item["symbol"])

    found_symbols = []
    for symbol in all_symbols:
        if return_field_data:
            item = symbol
            symbol = item["symbol"]
        else:
            item = symbol
        # NOTE(review): skips entries whose symbol string contains a literal '"symbol"' -
        # presumably malformed payload fragments; confirm against the client
        if '"symbol"' in symbol:
            continue
        # Symbols come in the `EXCHANGE:SYMBOL` format
        if exchange_pattern is not None:
            if not cls.key_match(symbol.split(":")[0], exchange_pattern, use_regex=use_regex):
                continue
        if symbol_pattern is not None:
            if not cls.key_match(symbol.split(":")[1], symbol_pattern, use_regex=use_regex):
                continue
        found_symbols.append(item)

    if return_field_data:
        if sort:
            return sorted(found_symbols, key=lambda x: x["symbol"])
        return found_symbols
    # Remove duplicates while keeping the order of the first occurrence
    unique_symbols = dict.fromkeys(found_symbols)
    if sort:
        return sorted(unique_symbols)
    return list(unique_symbols)

@classmethod
def resolve_client(cls, client: tp.Optional[TVClient] = None, **client_config) -> TVClient:
    """Resolve the client.

    If provided, must be of the type `TVClient`. Otherwise, will be created using `client_config`
    merged over the `client_config` setting.

    Raises:
        ValueError: If `client_config` is passed along with an already initialized client.
    """
    client = cls.resolve_custom_setting(client, "client")
    # Must be checked before merging with the settings, which may add keys of their own
    has_client_config = len(client_config) > 0
    client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
    if client is None:
        client = TVClient(**client_config)
    elif has_client_config:
        raise ValueError("Cannot apply client_config to already initialized client")
    return client

@classmethod
def fetch_symbol(
    cls,
    symbol: str,
    client: tp.Optional[TVClient] = None,
    client_config: tp.KwargsLike = None,
    exchange: tp.Optional[str] = None,
    timeframe: tp.Optional[str] = None,
    tz: tp.TimezoneLike = None,
    fut_contract: tp.Optional[int] = None,
    adjustment: tp.Optional[str] = None,
    extended_session: tp.Optional[bool] = None,
    pro_data: tp.Optional[bool] = None,
    limit: tp.Optional[int] = None,
    delay: tp.Optional[float] = None,
    retries: tp.Optional[int] = None,
) -> tp.SymbolData:
    """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from TradingView.

    Args:
        symbol (str): Symbol.

            Symbol must be in the `EXCHANGE:SYMBOL` format if `exchange` is None.
        client (TVClient): Client.

            See `TVData.resolve_client`.
        client_config (dict): Client config.

            See `TVData.resolve_client`.
        exchange (str): Exchange.

            Can be omitted if already provided via `symbol`.
        timeframe (str): Timeframe.

            Allows human-readable strings such as "15 minutes".
        tz (any): Timezone.

            See `vectorbtpro.utils.datetime_.to_timezone`.
        fut_contract (int): None for cash, 1 for continuous current contract in front,
            2 for continuous next contract in front.
        adjustment (str): Adjustment.

            Either "splits" (default) or "dividends".
        extended_session (bool): Regular session if False, extended session if True.
        pro_data (bool): Whether to use pro data.
        limit (int): The maximum number of returned items.
        delay (float): Time to sleep after each failed request (in seconds).
        retries (int): The number of attempts to fetch data.

            At least one attempt is always made.

    Returns:
        Tuple of the fetched DataFrame and a dict with the keys `tz` and `freq`.

    Raises:
        ValueError: If `timeframe` is not a string or cannot be translated into an interval.

    For defaults, see `custom.tv` in `vectorbtpro._settings.data`.
    """
    if client_config is None:
        client_config = {}
    client = cls.resolve_client(client=client, **client_config)

    exchange = cls.resolve_custom_setting(exchange, "exchange")
    timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
    tz = cls.resolve_custom_setting(tz, "tz")
    fut_contract = cls.resolve_custom_setting(fut_contract, "fut_contract")
    adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
    extended_session = cls.resolve_custom_setting(extended_session, "extended_session")
    pro_data = cls.resolve_custom_setting(pro_data, "pro_data")
    limit = cls.resolve_custom_setting(limit, "limit")
    delay = cls.resolve_custom_setting(delay, "delay")
    retries = cls.resolve_custom_setting(retries, "retries")

    # The original timeframe is what gets returned as frequency
    freq = timeframe
    if not isinstance(timeframe, str):
        raise ValueError(f"Invalid timeframe: '{timeframe}'")
    split = dt.split_freq_str(timeframe)
    if split is None:
        raise ValueError(f"Invalid timeframe: '{timeframe}'")
    multiplier, unit = split
    # Interval suffix per unit; minutes are passed as a bare number
    interval_suffixes = {"s": "S", "m": "", "h": "H", "D": "D", "W": "W", "M": "M"}
    if unit not in interval_suffixes:
        raise ValueError(f"Invalid timeframe: '{timeframe}'")
    interval = f"{multiplier}{interval_suffixes[unit]}"

    # Always attempt at least once: with zero (or negative) retries the loop would
    # otherwise never run and leave `df` undefined
    num_attempts = max(retries, 1)
    for attempt in range(num_attempts):
        try:
            df = client.get_hist(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                fut_contract=fut_contract,
                adjustment=adjustment,
                extended_session=extended_session,
                pro_data=pro_data,
                limit=limit,
            )
            break
        except Exception:
            if attempt == num_attempts - 1:
                # Out of attempts: re-raise with the original traceback
                raise
            if delay is not None:
                time.sleep(delay)

    df.rename(
        columns={
            "symbol": "Symbol",
            "open": "Open",
            "high": "High",
            "low": "Low",
            "close": "Close",
            "volume": "Volume",
        },
        inplace=True,
    )
    # A timezone-naive index is localized to UTC
    if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
        df = df.tz_localize("utc")

    if "Symbol" in df:
        del df["Symbol"]
    for feature in ("Open", "High", "Low", "Close", "Volume"):
        if feature in df.columns:
            df[feature] = df[feature].astype(float)

    return df, dict(tz=tz, freq=freq)

def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
    """Re-fetch a symbol using the keyword arguments it was fetched with, overridden by `kwargs`."""
    fetch_kwargs = self.select_fetch_kwargs(symbol)
    kwargs = merge_dicts(fetch_kwargs, kwargs)
    return self.fetch_symbol(symbol, **kwargs)
# ===================================================================================

"""Module with `YFData`."""

import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.parsing import get_func_kwargs

__all__ = [
    "YFData",
]

__pdoc__ = {}


class YFData(RemoteData):
    """Data class that pulls price history from Yahoo Finance.

    The API is documented at https://github.com/ranaroussi/yfinance.

    The accepted arguments are described in `YFData.fetch_symbol`.

    Usage:
        ```pycon
        >>> from vectorbtpro import *

        >>> data = vbt.YFData.pull(
        ...     "BTC-USD",
        ...     start="2020-01-01",
        ...     end="2021-01-01",
        ...     timeframe="1 day"
        ... )
        ```
    """

    _settings_path: tp.SettingsPath = dict(custom="data.custom.yf")

    # Per-feature resampling rules: dividends and capital gains are summed,
    # stock splits are reduced with a product over non-zero values
    _feature_config: tp.ClassVar[Config] = HybridConfig(
        {
            "Dividends": {
                "resample_func": lambda self, obj, resampler: obj.vbt.resample_apply(
                    resampler,
                    generic_nb.sum_reduce_nb,
                ),
            },
            "Stock Splits": {
                "resample_func": lambda self, obj, resampler: obj.vbt.resample_apply(
                    resampler,
                    generic_nb.nonzero_prod_reduce_nb,
                ),
            },
            "Capital Gains": {
                "resample_func": lambda self, obj, resampler: obj.vbt.resample_apply(
                    resampler,
                    generic_nb.sum_reduce_nb,
                ),
            },
        }
    )

    @property
    def feature_config(self) -> Config:
        return self._feature_config

    @classmethod
    def fetch_symbol(
        cls,
        symbol: str,
        period: tp.Optional[str] = None,
        start: tp.Optional[tp.DatetimeLike] = None,
        end: tp.Optional[tp.DatetimeLike] = None,
        timeframe: tp.Optional[str] = None,
        tz: tp.TimezoneLike = None,
        **history_kwargs,
    ) -> tp.SymbolData:
        """Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Yahoo Finance.

        Args:
            symbol (str): Symbol.
            period (str): Period.
            start (any): Start datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            end (any): End datetime.

                See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
            timeframe (str): Timeframe.

                Allows human-readable strings such as "15 minutes".
            tz (any): Timezone.

                See `vectorbtpro.utils.datetime_.to_timezone`. Defaults to the timezone of the ticker.
            **history_kwargs: Keyword arguments passed to `yfinance.base.TickerBase.history`.

        Returns:
            Tuple of the fetched DataFrame and a dict with the keys `tz` and `freq`.

        For defaults, see `custom.yf` in `vectorbtpro._settings.data`.

        !!! warning
            Data coming from Yahoo is not the most stable data out there. Yahoo may manipulate data
            how they want, add noise, return missing data points, etc.
            It's only used in vectorbt for demonstration purposes.
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("yfinance")
        import yfinance as yf

        period = cls.resolve_custom_setting(period, "period")
        start = cls.resolve_custom_setting(start, "start")
        end = cls.resolve_custom_setting(end, "end")
        timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
        tz = cls.resolve_custom_setting(tz, "tz")
        history_kwargs = cls.resolve_custom_setting(history_kwargs, "history_kwargs", merge=True)

        ticker = yf.Ticker(symbol)
        # Look up the ticker's timezone with the same proxy/timeout the history call will use
        history_defaults = get_func_kwargs(yf.Tickers.history)
        proxy = history_kwargs.get("proxy", history_defaults["proxy"])
        timeout = history_kwargs.get("timeout", history_defaults["timeout"])
        ticker_tz = ticker._get_ticker_tz(proxy, timeout)
        if tz is None:
            tz = ticker_tz

        # Naive bounds are interpreted in `tz` and then expressed in the ticker's timezone
        if start is not None:
            start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=ticker_tz)
        if end is not None:
            end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=ticker_tz)

        # The original timeframe is what gets returned as frequency
        freq = timeframe
        freq_parts = dt.split_freq_str(timeframe)
        if freq_parts is not None:
            multiplier, unit = freq_parts
            # Day, week, and month units have their own spelling in yfinance intervals
            unit = {"D": "d", "W": "wk", "M": "mo"}.get(unit, unit)
            timeframe = str(multiplier) + unit

        df = ticker.history(period=period, start=start, end=end, interval=timeframe, **history_kwargs)
        if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
            df = df.tz_localize(ticker_tz)

        if not df.empty:
            # Cut off rows outside of [start, end)
            if start is not None and df.index[0] < start:
                df = df[df.index >= start]
            if end is not None and df.index[-1] >= end:
                df = df[df.index < end]
        return df, dict(tz=tz, freq=freq)

    def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
        """Fetch a symbol again, starting at its last stored index and reusing its fetch arguments."""
        fetch_kwargs = self.select_fetch_kwargs(symbol)
        fetch_kwargs["start"] = self.select_last_index(symbol)
        return self.fetch_symbol(symbol, **merge_dicts(fetch_kwargs, kwargs))


YFData.override_feature_config_doc(__pdoc__)
# ===================================================================================

"""Base class for working with data sources."""

import inspect
import string
import traceback
from pathlib import Path

import numpy as np
import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes
from vectorbtpro.base.merging import column_stack_arrays, is_merge_func_from_config
from vectorbtpro.base.reshaping import to_any_array, to_pd_array, to_1d_array, to_2d_array, broadcast_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.data.decorators import attach_symbol_dict_methods
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.ohlcv.nb import mirror_ohlc_nb
from vectorbtpro.ohlcv.enums import PriceFeature
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig, copy_dict
from vectorbtpro.utils.decorators import cached_property, hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute
from vectorbtpro.utils.merging import MergeFunc
from vectorbtpro.utils.parsing import get_func_arg_names, extend_args
from vectorbtpro.utils.path_ import check_mkdir
from vectorbtpro.utils.pickling import pdict, RecState
from vectorbtpro.utils.template import Rep, RepEval, CustomTemplate, substitute_templates
from vectorbtpro.utils.warnings_ import warn
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg

try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from sqlalchemy import Engine as EngineT
except ImportError:
    EngineT = "Engine"
try:
    if not tp.TYPE_CHECKING:
        raise ImportError
    from duckdb import DuckDBPyConnection as DuckDBPyConnectionT
except ImportError:
    DuckDBPyConnectionT = "DuckDBPyConnection"

__all__ = [
    "key_dict",
    "feature_dict",
    "symbol_dict",
    "run_func_dict",
    "run_arg_dict",
    "Data",
]

__pdoc__ = {}


class key_dict(pdict):
    """Dict keyed by either features or symbols."""


class feature_dict(key_dict):
    """Dict keyed by features."""


class symbol_dict(key_dict):
    """Dict keyed by symbols."""


class run_func_dict(pdict):
    """Dict keyed by function names, meant for `Data.run`."""


class run_arg_dict(pdict):
    """Dict keyed by argument names, meant for `Data.run`."""


BaseDataMixinT = tp.TypeVar("BaseDataMixinT", bound="BaseDataMixin")


class BaseDataMixin(Base):
    """Mixin with feature and symbol lookup and selection shared by data classes."""

    @property
    def feature_wrapper(self) -> ArrayWrapper:
        """Wrapper whose columns are features. Must be implemented by the subclass."""
        raise NotImplementedError

    @property
    def symbol_wrapper(self) -> ArrayWrapper:
        """Wrapper whose columns are symbols. Must be implemented by the subclass."""
        raise NotImplementedError

    @property
    def features(self) -> tp.List[tp.Feature]:
        """Columns of `BaseDataMixin.feature_wrapper` as a list."""
        return self.feature_wrapper.columns.tolist()

    @property
    def symbols(self) -> tp.List[tp.Symbol]:
        """Columns of `BaseDataMixin.symbol_wrapper` as a list."""
        return self.symbol_wrapper.columns.tolist()

    @classmethod
    def has_multiple_keys(cls, keys: tp.MaybeKeys) -> bool:
        """Return False for a single (hashable) key and True for a sequence of keys.

        Raises a `TypeError` for anything else."""
        # Hashability is tested first, so hashable containers count as a single key
        if checks.is_hashable(keys):
            return False
        if checks.is_sequence(keys):
            return True
        raise TypeError("Keys must be either a hashable or a sequence of hashable")

    @classmethod
    def prepare_key(cls, key: tp.Key) -> tp.Key:
        """Normalize a key for comparison.

        Strings are lower-cased, stripped, and have spaces replaced by underscores.
        Tuples are normalized element-wise. Other keys are returned unchanged."""
        if isinstance(key, tuple):
            return tuple(map(cls.prepare_key, key))
        if isinstance(key, str):
            return key.lower().strip().replace(" ", "_")
        return key

    def get_feature_idx(self, feature: tp.Feature, raise_error: bool = False) -> int:
        """Return the position of a feature.

        Returns -1 if the feature cannot be found and `raise_error` is False.
        Raises a `ValueError` if multiple features match."""
        feature_columns = self.feature_wrapper.columns
        # Fast path: an exact label in a duplicate-free index
        if not feature_columns.has_duplicates and feature in feature_columns:
            return feature_columns.get_loc(feature)

        # Slow path: compare the normalized keys
        target = self.prepare_key(feature)
        matches = [i for i, label in enumerate(self.features) if target == self.prepare_key(label)]
        if not matches:
            if raise_error:
                raise ValueError(f"No features match the feature '{str(target)}'")
            return -1
        if len(matches) > 1:
            raise ValueError(f"Multiple features match the feature '{str(target)}'")
        return matches[0]

    def get_symbol_idx(self, symbol: tp.Symbol, raise_error: bool = False) -> int:
        """Return the position of a symbol.

        Returns -1 if the symbol cannot be found and `raise_error` is False.
        Raises a `ValueError` if multiple symbols match."""
        symbol_columns = self.symbol_wrapper.columns
        # Fast path: an exact label in a duplicate-free index
        if not symbol_columns.has_duplicates and symbol in symbol_columns:
            return symbol_columns.get_loc(symbol)

        # Slow path: compare the normalized keys
        target = self.prepare_key(symbol)
        matches = [i for i, label in enumerate(self.symbols) if target == self.prepare_key(label)]
        if not matches:
            if raise_error:
                raise ValueError(f"No symbols match the symbol '{str(target)}'")
            return -1
        if len(matches) > 1:
            raise ValueError(f"Multiple symbols match the symbol '{str(target)}'")
        return matches[0]

    def select_feature_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
        """Select one or more features by position and return a new instance.

        Must be implemented by the subclass."""
        raise NotImplementedError

    def select_symbol_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
        """Select one or more symbols by position and return a new instance.

        Must be implemented by the subclass."""
        raise NotImplementedError

    def select_features(self: BaseDataMixinT, features: tp.MaybeFeatures, **kwargs) -> BaseDataMixinT:
        """Select one or more features by label and return a new instance.

        Each label must exist. Keyword arguments are passed to `BaseDataMixin.select_feature_idxs`."""
        if self.has_multiple_keys(features):
            idxs = [self.get_feature_idx(label, raise_error=True) for label in features]
        else:
            idxs = self.get_feature_idx(features, raise_error=True)
        return self.select_feature_idxs(idxs, **kwargs)

    def select_symbols(self: BaseDataMixinT, symbols: tp.MaybeSymbols, **kwargs) -> BaseDataMixinT:
        """Select one or more symbols by label and return a new instance.

        Each label must exist. Keyword arguments are passed to `BaseDataMixin.select_symbol_idxs`."""
        if self.has_multiple_keys(symbols):
            idxs = [self.get_symbol_idx(label, raise_error=True) for label in symbols]
        else:
            idxs = self.get_symbol_idx(symbols, raise_error=True)
        return self.select_symbol_idxs(idxs, **kwargs)

    def get(
        self,
        features: tp.Optional[tp.MaybeFeatures] = None,
        symbols: tp.Optional[tp.MaybeSymbols] = None,
        feature: tp.Optional[tp.Feature] = None,
        symbol: tp.Optional[tp.Symbol] = None,
        **kwargs,
    ) -> tp.MaybeTuple[tp.SeriesFrame]:
        """Get one or more features of one or more symbols of data.

        Must be implemented by the subclass."""
        raise NotImplementedError

    def has_feature(self, feature: tp.Feature) -> bool:
        """Return whether the feature can be found."""
        return self.get_feature_idx(feature, raise_error=False) != -1

    def has_symbol(self, symbol: tp.Symbol) -> bool:
        """Return whether the symbol can be found."""
        return self.get_symbol_idx(symbol, raise_error=False) != -1

    def assert_has_feature(self, feature: tp.Feature) -> None:
        """Raise an error if the feature cannot be found."""
        self.get_feature_idx(feature, raise_error=True)

    def assert_has_symbol(self, symbol: tp.Symbol) -> None:
        """Raise an error if the symbol cannot be found."""
        self.get_symbol_idx(symbol, raise_error=True)

    def get_feature(
        self,
        feature: tp.Union[int, tp.Feature],
        raise_error: bool = False,
    ) -> tp.Optional[tp.SeriesFrame]:
        """Get the data of a feature given by its position or label.

        Returns None if a label cannot be found and `raise_error` is False."""
        if checks.is_int(feature):
            # An integer is a position rather than a label
            return self.get(features=self.features[feature])
        idx = self.get_feature_idx(feature, raise_error=raise_error)
        return None if idx == -1 else self.get(features=self.features[idx])

    def get_symbol(
        self,
        symbol: tp.Union[int, tp.Symbol],
        raise_error: bool = False,
    ) -> tp.Optional[tp.SeriesFrame]:
        """Get the data of a symbol given by its position or label.

        Returns None if a label cannot be found and `raise_error` is False."""
        if checks.is_int(symbol):
            # An integer is a position rather than a label
            return self.get(symbol=self.symbols[symbol])
        idx = self.get_symbol_idx(symbol, raise_error=raise_error)
        return None if idx == -1 else self.get(symbol=self.symbols[idx])


OHLCDataMixinT = tp.TypeVar("OHLCDataMixinT", bound="OHLCDataMixin")


class OHLCDataMixin(BaseDataMixin):
    """Mixin with shortcuts for OHLC(V) features and metrics derived from them."""

    @property
    def open(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "Open", or None if missing."""
        return self.get_feature("Open")

    @property
    def high(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "High", or None if missing."""
        return self.get_feature("High")

    @property
    def low(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "Low", or None if missing."""
        return self.get_feature("Low")

    @property
    def close(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "Close", or None if missing."""
        return self.get_feature("Close")

    @property
    def volume(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "Volume", or None if missing."""
        return self.get_feature("Volume")

    @property
    def trade_count(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "Trade count", or None if missing."""
        return self.get_feature("Trade count")

    @property
    def vwap(self) -> tp.Optional[tp.SeriesFrame]:
        """Feature "VWAP", or None if missing."""
        return self.get_feature("VWAP")

    @property
    def hlc3(self) -> tp.SeriesFrame:
        """Average of High, Low, and Close. All three features are required."""
        high, low, close = [self.get_feature(name, raise_error=True) for name in ("High", "Low", "Close")]
        return (high + low + close) / 3

    @property
    def ohlc4(self) -> tp.SeriesFrame:
        """Average of Open, High, Low, and Close. All four features are required."""
        open_, high, low, close = [
            self.get_feature(name, raise_error=True) for name in ("Open", "High", "Low", "Close")
        ]
        return (open_ + high + low + close) / 4

    @property
    def has_any_ohlc(self) -> bool:
        """Whether at least one of the OHLC features is present."""
        return any(self.has_feature(name) for name in ("Open", "High", "Low", "Close"))

    @property
    def has_ohlc(self) -> bool:
        """Whether every OHLC feature is present."""
        return all(self.has_feature(name) for name in ("Open", "High", "Low", "Close"))

    @property
    def has_any_ohlcv(self) -> bool:
        """Whether at least one of the OHLCV features is present."""
        return self.has_any_ohlc or self.has_feature("Volume")

    @property
    def has_ohlcv(self) -> bool:
        """Whether every OHLCV feature is present."""
        return self.has_ohlc and self.has_feature("Volume")

    @property
    def ohlc(self: OHLCDataMixinT) -> OHLCDataMixinT:
        """New instance reduced to the Open, High, Low, and Close features (in that order)."""
        idxs = [self.get_feature_idx(name, raise_error=True) for name in ("Open", "High", "Low", "Close")]
        return self.select_feature_idxs(idxs)

    @property
    def ohlcv(self: OHLCDataMixinT) -> OHLCDataMixinT:
        """New instance reduced to the Open, High, Low, Close, and Volume features (in that order)."""
        idxs = [
            self.get_feature_idx(name, raise_error=True) for name in ("Open", "High", "Low", "Close", "Volume")
        ]
        return self.select_feature_idxs(idxs)

    def get_returns_acc(self, **kwargs) -> ReturnsAccessor:
        """Build a `vectorbtpro.returns.accessors.ReturnsAccessor` from the feature "Close".

        Keyword arguments are passed to `ReturnsAccessor.from_value`."""
        close = self.get_feature("Close", raise_error=True)
        return ReturnsAccessor.from_value(close, wrapper=self.symbol_wrapper, return_values=False, **kwargs)

    @property
    def returns_acc(self) -> ReturnsAccessor:
        """`OHLCDataMixin.get_returns_acc` with default arguments."""
        return self.get_returns_acc()

    def get_returns(self, **kwargs) -> tp.SeriesFrame:
        """Compute returns from the feature "Close".

        Keyword arguments are passed to `ReturnsAccessor.from_value`."""
        close = self.get_feature("Close", raise_error=True)
        return ReturnsAccessor.from_value(close, wrapper=self.symbol_wrapper, return_values=True, **kwargs)

    @property
    def returns(self) -> tp.SeriesFrame:
        """`OHLCDataMixin.get_returns` with default arguments."""
        return self.get_returns()

    def get_log_returns(self, **kwargs) -> tp.SeriesFrame:
        """Compute log returns from the feature "Close".

        Keyword arguments are passed to `ReturnsAccessor.from_value`."""
        close = self.get_feature("Close", raise_error=True)
        return ReturnsAccessor.from_value(
            close,
            wrapper=self.symbol_wrapper,
            return_values=True,
            log_returns=True,
            **kwargs,
        )

    @property
    def log_returns(self) -> tp.SeriesFrame:
        """`OHLCDataMixin.get_log_returns` with default arguments."""
        return self.get_log_returns()

    def get_daily_returns(self, **kwargs) -> tp.SeriesFrame:
        """Compute daily returns from the feature "Close" using `ReturnsAccessor.daily`.

        Keyword arguments are passed to `ReturnsAccessor.from_value`."""
        close = self.get_feature("Close", raise_error=True)
        returns_acc = ReturnsAccessor.from_value(
            close,
            wrapper=self.symbol_wrapper,
            return_values=False,
            **kwargs,
        )
        return returns_acc.daily()

    @property
    def daily_returns(self) -> tp.SeriesFrame:
        """`OHLCDataMixin.get_daily_returns` with default arguments."""
        return self.get_daily_returns()

    def get_daily_log_returns(self, **kwargs) -> tp.SeriesFrame:
        """Compute daily log returns from the feature "Close" using `ReturnsAccessor.daily`.

        Keyword arguments are passed to `ReturnsAccessor.from_value`."""
        close = self.get_feature("Close", raise_error=True)
        returns_acc = ReturnsAccessor.from_value(
            close,
            wrapper=self.symbol_wrapper,
            return_values=False,
            log_returns=True,
            **kwargs,
        )
        return returns_acc.daily()

    @property
    def daily_log_returns(self) -> tp.SeriesFrame:
        """`OHLCDataMixin.get_daily_log_returns` with default arguments."""
        return self.get_daily_log_returns()

    def get_drawdowns(self, **kwargs) -> Drawdowns:
        """Build drawdown records from the OHLC features, all of which are required.

        Keyword arguments are passed to `vectorbtpro.generic.drawdowns.Drawdowns.from_price`."""
        ohlc = {name.lower(): self.get_feature(name, raise_error=True) for name in ("Open", "High", "Low", "Close")}
        return Drawdowns.from_price(**ohlc, **kwargs)

    @property
    def drawdowns(self) -> Drawdowns:
        """`OHLCDataMixin.get_drawdowns` with default arguments."""
        return self.get_drawdowns()


DataT = tp.TypeVar("DataT", bound="Data")


class MetaData(type(Analyzable)):
    """Metaclass of `Data` that exposes the feature config on the class itself."""

    @property
    def feature_config(cls) -> Config:
        """Feature config stored in the class attribute `_feature_config`."""
        return cls._feature_config
""" return self._feature_config def use_feature_config_of(self, cls: tp.Type[DataT]) -> None: """Copy feature config from another `Data` class.""" self._feature_config = cls.feature_config.copy() @classmethod def modify_state(cls, rec_state: RecState) -> RecState: # Ensure backward compatibility if "_column_config" in rec_state.attr_dct and "_feature_config" not in rec_state.attr_dct: new_attr_dct = dict(rec_state.attr_dct) new_attr_dct["_feature_config"] = new_attr_dct.pop("_column_config") rec_state = RecState( init_args=rec_state.init_args, init_kwargs=rec_state.init_kwargs, attr_dct=new_attr_dct, ) if "single_symbol" in rec_state.init_kwargs and "single_key" not in rec_state.init_kwargs: new_init_kwargs = dict(rec_state.init_kwargs) new_init_kwargs["single_key"] = new_init_kwargs.pop("single_symbol") rec_state = RecState( init_args=rec_state.init_args, init_kwargs=new_init_kwargs, attr_dct=rec_state.attr_dct, ) if "symbol_classes" in rec_state.init_kwargs and "classes" not in rec_state.init_kwargs: new_init_kwargs = dict(rec_state.init_kwargs) new_init_kwargs["classes"] = new_init_kwargs.pop("symbol_classes") rec_state = RecState( init_args=rec_state.init_args, init_kwargs=new_init_kwargs, attr_dct=rec_state.attr_dct, ) return rec_state @classmethod def fix_data_dict_type(cls, data: dict) -> tp.Union[feature_dict, symbol_dict]: """Fix dict type for data.""" checks.assert_instance_of(data, dict, arg_name="data") if not isinstance(data, key_dict): data = symbol_dict(data) return data @classmethod def fix_dict_types_in_kwargs( cls, data_type: tp.Type[tp.Union[feature_dict, symbol_dict]], **kwargs: tp.Kwargs, ) -> tp.Kwargs: """Fix dict types in keyword arguments.""" for attr in cls._key_dict_attrs: if attr in kwargs: attr_value = kwargs[attr] if attr_value is None: attr_value = {} checks.assert_instance_of(attr_value, dict, arg_name=attr) if not isinstance(attr_value, key_dict): attr_value = data_type(attr_value) if attr in cls._data_dict_type_attrs: 
checks.assert_instance_of(attr_value, data_type, arg_name=attr) kwargs[attr] = attr_value return kwargs @hybrid_method def row_stack( cls_or_self: tp.MaybeType[DataT], *objs: tp.MaybeTuple[DataT], wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> DataT: """Stack multiple `Data` instances along rows. Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.""" if not isinstance(cls_or_self, type): objs = (cls_or_self, *objs) cls = type(cls_or_self) else: cls = cls_or_self if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, Data): raise TypeError("Each object to be merged must be an instance of Data") if "wrapper" not in kwargs: if wrapper_kwargs is None: wrapper_kwargs = {} kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs) keys = set() for obj in objs: keys = keys.union(set(obj.data.keys())) data_type = None for obj in objs: if len(keys.difference(set(obj.data.keys()))) > 0: if isinstance(obj.data, feature_dict): raise ValueError("Objects to be merged must have the same features") else: raise ValueError("Objects to be merged must have the same symbols") if data_type is None: data_type = type(obj.data) elif not isinstance(obj.data, data_type): raise TypeError("Objects to be merged must have the same dict type for data") if "data" not in kwargs: new_data = data_type() for k in objs[0].data.keys(): new_data[k] = kwargs["wrapper"].row_stack_arrs(*[obj.data[k] for obj in objs], group_by=False) kwargs["data"] = new_data kwargs["data"] = cls.fix_data_dict_type(kwargs["data"]) for attr in cls._key_dict_attrs: if attr not in kwargs: attr_data_type = None for obj in objs: v = getattr(obj, attr) if attr_data_type is None: attr_data_type = type(v) elif not isinstance(v, attr_data_type): raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'") kwargs[attr] = getattr(objs[-1], attr) kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs) 
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs) return cls(**kwargs) @hybrid_method def column_stack( cls_or_self: tp.MaybeType[DataT], *objs: tp.MaybeTuple[DataT], wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> DataT: """Stack multiple `Data` instances along columns. Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.""" if not isinstance(cls_or_self, type): objs = (cls_or_self, *objs) cls = type(cls_or_self) else: cls = cls_or_self if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, Data): raise TypeError("Each object to be merged must be an instance of Data") if "wrapper" not in kwargs: if wrapper_kwargs is None: wrapper_kwargs = {} kwargs["wrapper"] = ArrayWrapper.column_stack( *[obj.wrapper for obj in objs], **wrapper_kwargs, ) keys = set() for obj in objs: keys = keys.union(set(obj.data.keys())) data_type = None for obj in objs: if len(keys.difference(set(obj.data.keys()))) > 0: if isinstance(obj.data, feature_dict): raise ValueError("Objects to be merged must have the same features") else: raise ValueError("Objects to be merged must have the same symbols") if data_type is None: data_type = type(obj.data) elif not isinstance(obj.data, data_type): raise TypeError("Objects to be merged must have the same dict type for data") if "data" not in kwargs: new_data = data_type() for k in objs[0].data.keys(): new_data[k] = kwargs["wrapper"].column_stack_arrs(*[obj.data[k] for obj in objs], group_by=False) kwargs["data"] = new_data kwargs["data"] = cls.fix_data_dict_type(kwargs["data"]) for attr in cls._key_dict_attrs: if attr not in kwargs: attr_data_type = None for obj in objs: v = getattr(obj, attr) if attr_data_type is None: attr_data_type = type(v) elif not isinstance(v, attr_data_type): raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'") if (issubclass(data_type, feature_dict) 
and issubclass(attr_data_type, symbol_dict)) or ( issubclass(data_type, symbol_dict) and issubclass(attr_data_type, feature_dict) ): kwargs[attr] = attr_data_type() for obj in objs: kwargs[attr].update(**getattr(obj, attr)) kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs) kwargs = cls.resolve_stack_kwargs(*objs, **kwargs) kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs) return cls(**kwargs) def __init__( self, wrapper: ArrayWrapper, data: tp.Union[feature_dict, symbol_dict], single_key: bool = True, classes: tp.Union[None, feature_dict, symbol_dict] = None, level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, fetch_kwargs: tp.Union[None, feature_dict, symbol_dict] = None, returned_kwargs: tp.Union[None, feature_dict, symbol_dict] = None, last_index: tp.Union[None, feature_dict, symbol_dict] = None, delisted: tp.Union[None, feature_dict, symbol_dict] = None, tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None, tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None, missing_index: tp.Optional[str] = None, missing_columns: tp.Optional[str] = None, **kwargs, ) -> None: Analyzable.__init__( self, wrapper, data=data, single_key=single_key, classes=classes, level_name=level_name, fetch_kwargs=fetch_kwargs, returned_kwargs=returned_kwargs, last_index=last_index, delisted=delisted, tz_localize=tz_localize, tz_convert=tz_convert, missing_index=missing_index, missing_columns=missing_columns, **kwargs, ) if len(set(map(self.prepare_key, data.keys()))) < len(list(map(self.prepare_key, data.keys()))): raise ValueError("Found duplicate keys in data dictionary") data = self.fix_data_dict_type(data) for obj in data.values(): checks.assert_meta_equal(obj, data[list(data.keys())[0]]) if len(data) > 1: single_key = False self._data = data self._single_key = single_key self._level_name = level_name self._tz_localize = tz_localize self._tz_convert = tz_convert self._missing_index = missing_index self._missing_columns = 
missing_columns attr_kwargs = dict() for attr in self._key_dict_attrs: attr_value = locals()[attr] attr_kwargs[attr] = attr_value attr_kwargs = self.fix_dict_types_in_kwargs(type(data), **attr_kwargs) for k, v in attr_kwargs.items(): setattr(self, "_" + k, v) # Copy writeable attrs self._feature_config = type(self)._feature_config.copy() def replace(self: DataT, **kwargs) -> DataT: """See `vectorbtpro.utils.config.Configured.replace`. Replaces the data's index and/or columns if they were changed in the wrapper.""" if "wrapper" in kwargs and "data" not in kwargs: wrapper = kwargs["wrapper"] if isinstance(wrapper, dict): new_index = wrapper.get("index", self.wrapper.index) new_columns = wrapper.get("columns", self.wrapper.columns) else: new_index = wrapper.index new_columns = wrapper.columns data = self.config["data"] new_data = {} index_changed = False columns_changed = False for k, v in data.items(): if isinstance(v, (pd.Series, pd.DataFrame)): if not checks.is_index_equal(v.index, new_index): v = v.copy(deep=False) v.index = new_index index_changed = True if isinstance(v, pd.DataFrame): if not checks.is_index_equal(v.columns, new_columns): v = v.copy(deep=False) v.columns = new_columns columns_changed = True new_data[k] = v if index_changed or columns_changed: kwargs["data"] = self.fix_data_dict_type(new_data) if columns_changed: rename = dict(zip(self.keys, new_columns)) for attr in self._key_dict_attrs: if attr not in kwargs: attr_value = getattr(self, attr) if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or ( self.symbol_oriented and isinstance(attr_value, feature_dict) ): kwargs[attr] = self.rename_in_dict(getattr(self, attr), rename) kwargs = self.fix_dict_types_in_kwargs(type(kwargs.get("data", self.data)), **kwargs) return Analyzable.replace(self, **kwargs) def indexing_func(self: DataT, *args, replace_kwargs: tp.KwargsLike = None, **kwargs) -> DataT: """Perform indexing on `Data`.""" if replace_kwargs is None: replace_kwargs = {} 
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs) new_wrapper = wrapper_meta["new_wrapper"] new_data = self.dict_type() for k, v in self._data.items(): if wrapper_meta["rows_changed"]: v = v.iloc[wrapper_meta["row_idxs"]] if wrapper_meta["columns_changed"]: v = v.iloc[:, wrapper_meta["col_idxs"]] new_data[k] = v attr_dicts = dict() attr_dicts["last_index"] = type(self.last_index)() for k in self.last_index: attr_dicts["last_index"][k] = min([self.last_index[k], new_wrapper.index[-1]]) if wrapper_meta["columns_changed"]: new_symbols = new_wrapper.columns for attr in self._key_dict_attrs: attr_value = getattr(self, attr) if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or ( self.symbol_oriented and isinstance(attr_value, feature_dict) ): if attr in attr_dicts: attr_dicts[attr] = self.select_from_dict(attr_dicts[attr], new_symbols) else: attr_dicts[attr] = self.select_from_dict(attr_value, new_symbols) return self.replace(wrapper=new_wrapper, data=new_data, **attr_dicts, **replace_kwargs) @property def data(self) -> tp.Union[feature_dict, symbol_dict]: """Data dictionary. 
Has the type `feature_dict` for feature-oriented data or `symbol_dict` for symbol-oriented data.""" return self._data @property def dict_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]: """Return the dict type.""" return type(self.data) @property def column_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]: """Return the column type.""" if isinstance(self.data, feature_dict): return symbol_dict return feature_dict @property def feature_oriented(self) -> bool: """Whether data has features as keys.""" return issubclass(self.dict_type, feature_dict) @property def symbol_oriented(self) -> bool: """Whether data has symbols as keys.""" return issubclass(self.dict_type, symbol_dict) def get_keys(self, dict_type: tp.Type[tp.Union[feature_dict, symbol_dict]]) -> tp.List[tp.Key]: """Get keys depending on the provided dict type.""" checks.assert_subclass_of(dict_type, (feature_dict, symbol_dict), arg_name="dict_type") if issubclass(dict_type, feature_dict): return self.features return self.symbols @property def keys(self) -> tp.List[tp.Union[tp.Feature, tp.Symbol]]: """Keys in data. 
Features if `feature_dict` and symbols if `symbol_dict`.""" return list(self.data.keys()) @property def single_key(self) -> bool: """Whether there is only one key in `Data.data`.""" return self._single_key @property def single_feature(self) -> bool: """Whether there is only one feature in `Data.data`.""" if self.feature_oriented: return self.single_key return self.wrapper.ndim == 1 @property def single_symbol(self) -> bool: """Whether there is only one symbol in `Data.data`.""" if self.symbol_oriented: return self.single_key return self.wrapper.ndim == 1 @property def classes(self) -> tp.Union[feature_dict, symbol_dict]: """Key classes.""" return self._classes @property def feature_classes(self) -> tp.Optional[feature_dict]: """Feature classes.""" if self.feature_oriented: return self.classes return None @property def symbol_classes(self) -> tp.Optional[symbol_dict]: """Symbol classes.""" if self.symbol_oriented: return self.classes return None @hybrid_method def get_level_name( cls_or_self, keys: tp.Optional[tp.Keys] = None, level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, feature_oriented: tp.Optional[bool] = None, ) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]: """Get level name(s) for keys.""" if isinstance(cls_or_self, type): checks.assert_not_none(keys, arg_name="keys") checks.assert_not_none(feature_oriented, arg_name="feature_oriented") else: if keys is None: keys = cls_or_self.keys if level_name is None: level_name = cls_or_self._level_name if feature_oriented is None: feature_oriented = cls_or_self.feature_oriented first_key = keys[0] if isinstance(level_name, bool): if level_name: level_name = None else: return None if feature_oriented: key_prefix = "feature" else: key_prefix = "symbol" if isinstance(first_key, tuple): if level_name is None: level_name = ["%s_%d" % (key_prefix, i) for i in range(len(first_key))] if not checks.is_iterable(level_name) or isinstance(level_name, str): raise TypeError("Level name should be list-like 
for a MultiIndex") return tuple(level_name) if level_name is None: level_name = key_prefix return level_name @property def level_name(self) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]: """Level name(s) for keys. Keys are symbols or features depending on the data dict type. Must be a sequence if keys are tuples, otherwise a hashable. If False, no level names will be used.""" return self.get_level_name() @hybrid_method def get_key_index( cls_or_self, keys: tp.Optional[tp.Keys] = None, level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None, feature_oriented: tp.Optional[bool] = None, ) -> tp.Index: """Get key index.""" if isinstance(cls_or_self, type): checks.assert_not_none(keys, arg_name="keys") else: if keys is None: keys = cls_or_self.keys level_name = cls_or_self.get_level_name(keys=keys, level_name=level_name, feature_oriented=feature_oriented) if isinstance(level_name, tuple): return pd.MultiIndex.from_tuples(keys, names=level_name) return pd.Index(keys, name=level_name) @property def key_index(self) -> tp.Index: """Key index.""" return self.get_key_index() @property def fetch_kwargs(self) -> tp.Union[feature_dict, symbol_dict]: """Keyword arguments of type `symbol_dict` initially passed to `Data.fetch_symbol`.""" return self._fetch_kwargs @property def returned_kwargs(self) -> tp.Union[feature_dict, symbol_dict]: """Keyword arguments of type `symbol_dict` returned by `Data.fetch_symbol`.""" return self._returned_kwargs @property def last_index(self) -> tp.Union[feature_dict, symbol_dict]: """Last fetched index per symbol of type `symbol_dict`.""" return self._last_index @property def delisted(self) -> tp.Union[feature_dict, symbol_dict]: """Delisted flag per symbol of type `symbol_dict`.""" return self._delisted @property def tz_localize(self) -> tp.Union[None, bool, tp.TimezoneLike]: """Timezone to localize a datetime-naive index to, which is initially passed to `Data.pull`.""" return self._tz_localize @property def tz_convert(self) -> 
tp.Union[None, bool, tp.TimezoneLike]: """Timezone to convert a datetime-aware to, which is initially passed to `Data.pull`.""" return self._tz_convert @property def missing_index(self) -> tp.Optional[str]: """Argument `missing` passed to `Data.align_index`.""" return self._missing_index @property def missing_columns(self) -> tp.Optional[str]: """Argument `missing` passed to `Data.align_columns`.""" return self._missing_columns # ############# Settings ############# # @classmethod def get_base_settings(cls, *args, **kwargs) -> dict: """`CustomData.get_settings` with `path_id="base"`.""" return cls.get_settings(*args, path_id="base", **kwargs) @classmethod def has_base_settings(cls, *args, **kwargs) -> bool: """`CustomData.has_settings` with `path_id="base"`.""" return cls.has_settings(*args, path_id="base", **kwargs) @classmethod def get_base_setting(cls, *args, **kwargs) -> tp.Any: """`CustomData.get_setting` with `path_id="base"`.""" return cls.get_setting(*args, path_id="base", **kwargs) @classmethod def has_base_setting(cls, *args, **kwargs) -> bool: """`CustomData.has_setting` with `path_id="base"`.""" return cls.has_setting(*args, path_id="base", **kwargs) @classmethod def resolve_base_setting(cls, *args, **kwargs) -> tp.Any: """`CustomData.resolve_setting` with `path_id="base"`.""" return cls.resolve_setting(*args, path_id="base", **kwargs) @classmethod def set_base_settings(cls, *args, **kwargs) -> None: """`CustomData.set_settings` with `path_id="base"`.""" cls.set_settings(*args, path_id="base", **kwargs) # ############# Iteration ############# # def items( self, over: str = "symbols", group_by: tp.GroupByLike = None, apply_group_by: bool = False, keep_2d: bool = False, key_as_index: bool = False, ) -> tp.Items: """Iterate over columns (or groups if grouped and `Wrapping.group_select` is True), keys, features, or symbols. The respective mode can be selected with `over`. See `vectorbtpro.base.wrapping.Wrapping.items` for iteration over columns. 
Iteration over keys supports `group_by` but doesn't support `apply_group_by`.""" if ( over.lower() == "columns" or (over.lower() == "symbols" and self.feature_oriented) or (over.lower() == "features" and self.symbol_oriented) ): for k, v in Analyzable.items( self, group_by=group_by, apply_group_by=apply_group_by, keep_2d=keep_2d, key_as_index=key_as_index, ): yield k, v elif ( over.lower() == "keys" or (over.lower() == "features" and self.feature_oriented) or (over.lower() == "symbols" and self.symbol_oriented) ): if apply_group_by: raise ValueError("Cannot apply grouping to keys") if group_by is not None: key_wrapper = self.get_key_wrapper(group_by=group_by) if key_wrapper.get_ndim() == 1: if key_as_index: yield key_wrapper.get_columns(), self else: yield key_wrapper.get_columns()[0], self else: for group, group_idxs in key_wrapper.grouper.iter_groups(key_as_index=key_as_index): if keep_2d or len(group_idxs) > 1: yield group, self.select_keys([self.keys[i] for i in group_idxs]) else: yield group, self.select_keys(self.keys[group_idxs[0]]) else: key_wrapper = self.get_key_wrapper(attach_classes=False) if key_wrapper.ndim == 1: if key_as_index: yield key_wrapper.columns, self else: yield key_wrapper.columns[0], self else: for i in range(len(key_wrapper.columns)): if key_as_index: key = key_wrapper.columns[[i]] else: key = key_wrapper.columns[i] if keep_2d: yield key, self.select_keys([key]) else: yield key, self.select_keys(key) else: raise ValueError(f"Invalid over: '{over}'") # ############# Getting ############# # def get_key_wrapper( self, keys: tp.Optional[tp.MaybeKeys] = None, attach_classes: bool = True, clean_index_kwargs: tp.KwargsLike = None, group_by: tp.GroupByLike = None, **kwargs, ) -> ArrayWrapper: """Get wrapper with keys as columns. If `attach_classes` is True, attaches `Data.classes` by stacking them over the keys using `vectorbtpro.base.indexes.stack_indexes`. 
Other keyword arguments are passed to the constructor of the wrapper.""" if clean_index_kwargs is None: clean_index_kwargs = {} if keys is None: keys = self.keys ndim = 1 if self.single_key else 2 else: if self.has_multiple_keys(keys): ndim = 2 else: keys = [keys] ndim = 1 for key in keys: if self.feature_oriented: self.assert_has_feature(key) else: self.assert_has_symbol(key) new_columns = self.get_key_index(keys=keys) wrapper = self.wrapper.replace( columns=new_columns, ndim=ndim, grouper=None, **kwargs, ) if attach_classes: classes = [] all_have_classes = True for key in wrapper.columns: if key in self.classes: key_classes = self.classes[key] if len(key_classes) > 0: classes.append(key_classes) else: all_have_classes = False else: all_have_classes = False if len(classes) > 0 and not all_have_classes: if self.feature_oriented: raise ValueError("Some features have classes while others not") else: raise ValueError("Some symbols have classes while others not") if len(classes) > 0: classes_frame = pd.DataFrame(classes) if len(classes_frame.columns) == 1: classes_columns = pd.Index(classes_frame.iloc[:, 0]) else: classes_columns = pd.MultiIndex.from_frame(classes_frame) new_columns = stack_indexes((classes_columns, wrapper.columns), **clean_index_kwargs) wrapper = wrapper.replace(columns=new_columns) if group_by is not None: wrapper = wrapper.replace(group_by=group_by) return wrapper @cached_property def key_wrapper(self) -> ArrayWrapper: """Key wrapper.""" return self.get_key_wrapper() def get_feature_wrapper(self, features: tp.Optional[tp.MaybeFeatures] = None, **kwargs) -> ArrayWrapper: """Get wrapper with features as columns.""" if self.feature_oriented: return self.get_key_wrapper(keys=features, **kwargs) wrapper = self.wrapper if features is not None: wrapper = wrapper[features] return wrapper @cached_property def feature_wrapper(self) -> ArrayWrapper: return self.get_feature_wrapper() def get_symbol_wrapper(self, symbols: tp.Optional[tp.MaybeSymbols] = None, 
**kwargs) -> ArrayWrapper: """Get wrapper with symbols as columns.""" if self.symbol_oriented: return self.get_key_wrapper(keys=symbols, **kwargs) wrapper = self.wrapper if symbols is not None: wrapper = wrapper[symbols] return wrapper @cached_property def symbol_wrapper(self) -> ArrayWrapper: return self.get_symbol_wrapper() @property def ndim(self) -> int: """Number of dimensions. Based on the default symbol wrapper.""" return self.symbol_wrapper.ndim @property def shape(self) -> tp.Shape: """Shape. Based on the default symbol wrapper.""" return self.symbol_wrapper.shape @property def shape_2d(self) -> tp.Shape: """Shape as if the object was two-dimensional. Based on the default symbol wrapper.""" return self.symbol_wrapper.shape_2d @property def columns(self) -> tp.Index: """Columns. Based on the default symbol wrapper.""" return self.symbol_wrapper.columns @property def index(self) -> tp.Index: """Index. Based on the default symbol wrapper.""" return self.symbol_wrapper.index @property def freq(self) -> tp.Optional[tp.PandasFrequency]: """Frequency. 
Based on the default symbol wrapper.""" return self.symbol_wrapper.freq @property def features(self) -> tp.List[tp.Feature]: if self.feature_oriented: return self.keys return self.wrapper.columns.tolist() @property def symbols(self) -> tp.List[tp.Symbol]: if self.feature_oriented: return self.wrapper.columns.tolist() return self.keys def resolve_feature(self, feature: tp.Feature, raise_error: bool = False) -> tp.Optional[tp.Feature]: """Return the feature of this instance that matches the provided feature.""" feature_idx = self.get_feature_idx(feature, raise_error=raise_error) if feature_idx == -1: return None return self.features[feature_idx] def resolve_symbol(self, symbol: tp.Feature, raise_error: bool = False) -> tp.Optional[tp.Feature]: """Return the symbol of this instance that matches the provided symbol.""" symbol_idx = self.get_symbol_idx(symbol, raise_error=raise_error) if symbol_idx == -1: return None return self.symbols[symbol_idx] def resolve_key(self, key: tp.Key, raise_error: bool = False) -> tp.Optional[tp.Key]: """Return the key of this instance that matches the provided key.""" if self.feature_oriented: return self.resolve_feature(key, raise_error=raise_error) return self.resolve_symbol(key, raise_error=raise_error) def resolve_column(self, column: tp.Column, raise_error: bool = False) -> tp.Optional[tp.Column]: """Return the column of this instance that matches the provided column.""" if self.feature_oriented: return self.resolve_symbol(column, raise_error=raise_error) return self.resolve_feature(column, raise_error=raise_error) def resolve_features(self, features: tp.MaybeFeatures, raise_error: bool = True) -> tp.MaybeFeatures: """Return the features of this instance that match the provided features.""" if not self.has_multiple_keys(features): features = [features] single_feature = True else: single_feature = False new_features = [] for feature in features: new_features.append(self.resolve_feature(feature, raise_error=raise_error)) if 
single_feature: return new_features[0] return new_features def resolve_symbols(self, symbols: tp.MaybeSymbols, raise_error: bool = True) -> tp.MaybeSymbols: """Return the symbols of this instance that match the provided symbols.""" if not self.has_multiple_keys(symbols): symbols = [symbols] single_symbol = True else: single_symbol = False new_symbols = [] for symbol in symbols: new_symbols.append(self.resolve_symbol(symbol, raise_error=raise_error)) if single_symbol: return new_symbols[0] return new_symbols def resolve_keys(self, keys: tp.MaybeKeys, raise_error: bool = True) -> tp.MaybeKeys: """Return the keys of this instance that match the provided keys.""" if self.feature_oriented: return self.resolve_features(keys, raise_error=raise_error) return self.resolve_symbols(keys, raise_error=raise_error) def resolve_columns(self, columns: tp.MaybeColumns, raise_error: bool = True) -> tp.MaybeColumns: """Return the columns of this instance that match the provided columns.""" if self.feature_oriented: return self.resolve_symbols(columns, raise_error=raise_error) return self.resolve_features(columns, raise_error=raise_error) def concat( self, keys: tp.Optional[tp.Symbols] = None, attach_classes: bool = True, clean_index_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.Union[feature_dict, symbol_dict]: """Concatenate keys along columns.""" key_wrapper = self.get_key_wrapper( keys=keys, attach_classes=attach_classes, clean_index_kwargs=clean_index_kwargs, **kwargs, ) if keys is None: keys = self.keys new_data = self.column_type() first_data = self.data[keys[0]] if key_wrapper.ndim == 1: if isinstance(first_data, pd.Series): new_data[first_data.name] = key_wrapper.wrap(first_data.values, zero_to_none=False) else: for c in first_data.columns: new_data[c] = key_wrapper.wrap(first_data[c].values, zero_to_none=False) else: if isinstance(first_data, pd.Series): columns = pd.Index([first_data.name]) else: columns = first_data.columns for c in columns: col_data = [] for k in keys: 
if isinstance(self.data[k], pd.Series): col_data.append(self.data[k].values) else: col_data.append(self.data[k][c].values) new_data[c] = key_wrapper.wrap(column_stack_arrays(col_data), zero_to_none=False) return new_data def get( self, features: tp.Optional[tp.MaybeFeatures] = None, symbols: tp.Optional[tp.MaybeSymbols] = None, feature: tp.Optional[tp.Feature] = None, symbol: tp.Optional[tp.Symbol] = None, squeeze_features: bool = False, squeeze_symbols: bool = False, per: str = "feature", as_dict: bool = False, **kwargs, ) -> tp.Union[tp.MaybeTuple[tp.SeriesFrame], dict]: """Get one or more features of one or more symbols of data.""" if features is not None and feature is not None: raise ValueError("Must provide either features or feature, not both") if symbols is not None and symbol is not None: raise ValueError("Must provide either symbols or symbol, not both") if feature is not None: features = feature single_feature = True else: if features is None: features = self.features single_feature = self.single_feature if single_feature: features = features[0] else: single_feature = not self.has_multiple_keys(features) if not single_feature and squeeze_features and len(features) == 1: features = features[0] single_feature = True if symbol is not None: symbols = symbol single_symbol = True else: if symbols is None: symbols = self.symbols single_symbol = self.single_symbol if single_symbol: symbols = symbols[0] else: single_symbol = not self.has_multiple_keys(symbols) if not single_symbol and squeeze_symbols and len(symbols) == 1: symbols = symbols[0] single_symbol = True if not single_feature: feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features] features = [self.features[i] for i in feature_idxs] else: feature_idxs = self.get_feature_idx(features, raise_error=True) features = self.features[feature_idxs] if not single_symbol: symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols] symbols = [self.symbols[i] for i in symbol_idxs] 
else: symbol_idxs = self.get_symbol_idx(symbols, raise_error=True) symbols = self.symbols[symbol_idxs] def _get_objs(): if self.feature_oriented: if single_feature: if self.single_symbol: return list(self.data.values())[feature_idxs], features return list(self.data.values())[feature_idxs].iloc[:, symbol_idxs], features if single_symbol: concat_data = self.concat(keys=features, **kwargs) return list(concat_data.values())[symbol_idxs], symbols if per.lower() in ("symbol", "column"): concat_data = self.concat(keys=features, **kwargs) return tuple([list(concat_data.values())[i] for i in symbol_idxs]), symbols if per.lower() in ("feature", "key"): if self.single_feature: if self.single_symbol: return list(self.data.values())[feature_idxs], features return list(self.data.values())[feature_idxs].iloc[:, symbol_idxs], features if self.single_symbol: return tuple([list(self.data.values())[i] for i in feature_idxs]), features return tuple([list(self.data.values())[i].iloc[:, symbol_idxs] for i in feature_idxs]), features raise ValueError(f"Invalid per: '{per}'") else: if single_symbol: if self.single_feature: return self.data[self.symbols[symbol_idxs]], symbols return self.data[self.symbols[symbol_idxs]].iloc[:, feature_idxs], symbols if single_feature: concat_data = self.concat(keys=symbols, **kwargs) return list(concat_data.values())[feature_idxs], features if per.lower() in ("feature", "column"): concat_data = self.concat(keys=symbols, **kwargs) return tuple([list(concat_data.values())[i] for i in feature_idxs]), features if per.lower() in ("symbol", "key"): if self.single_symbol: if self.single_feature: return list(self.data.values())[symbol_idxs], symbols return list(self.data.values())[symbol_idxs].iloc[:, feature_idxs], symbols if self.single_feature: return tuple([list(self.data.values())[i] for i in symbol_idxs]), symbols return tuple([list(self.data.values())[i].iloc[:, feature_idxs] for i in symbol_idxs]), symbols raise ValueError(f"Invalid per: '{per}'") objs, 
keys = _get_objs() if as_dict: if isinstance(objs, tuple): return dict(zip(keys, objs)) return {keys: objs} return objs # ############# Pre- and post-processing ############# # @classmethod def prepare_dt_index( cls, index: tp.Index, parse_dates: bool = False, tz_localize: tp.TimezoneLike = None, tz_convert: tp.TimezoneLike = None, force_tz_convert: bool = False, remove_tz: bool = False, ) -> tp.SeriesFrame: """Prepare datetime index. If `parse_dates` is True, will try to convert the index with an object data type into a datetime format using `vectorbtpro.utils.datetime_.prepare_dt_index`. If `tz_localize` is not None, will localize a datetime-naive index into this timezone. If `tz_convert` is not None, will convert a datetime-aware index into this timezone. If `force_tz_convert` is True, will convert regardless of whether the index is datetime-aware.""" if parse_dates: if not isinstance(index, (pd.DatetimeIndex, pd.MultiIndex)) and index.dtype == object: index = dt.prepare_dt_index(index) if isinstance(index, pd.DatetimeIndex): if index.tz is None and tz_localize is not None: index = index.tz_localize(dt.to_timezone(tz_localize)) if tz_convert is not None: if index.tz is not None or force_tz_convert: index = index.tz_convert(dt.to_timezone(tz_convert)) if remove_tz and index.tz is not None: index = index.tz_localize(None) return index @classmethod def prepare_dt_column( cls, sr: tp.Series, parse_dates: bool = False, tz_localize: tp.TimezoneLike = None, tz_convert: tp.TimezoneLike = None, force_tz_convert: bool = False, remove_tz: bool = False, ) -> tp.Series: """Prepare datetime column. 
See `Data.prepare_dt_index` for arguments.""" index = cls.prepare_dt_index( pd.Index(sr), parse_dates=parse_dates, tz_localize=tz_localize, tz_convert=tz_convert, force_tz_convert=force_tz_convert, remove_tz=remove_tz, ) if isinstance(index, pd.DatetimeIndex): return pd.Series(index, index=sr.index, name=sr.name) return sr @classmethod def prepare_dt( cls, obj: tp.SeriesFrame, parse_dates: tp.Union[None, bool, tp.Sequence[str]] = True, to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = True, remove_utc_tz: bool = False, ) -> tp.Frame: """Prepare datetime index and columns. If `parse_dates` is True, will try to convert any index and column with object data type into a datetime format using `vectorbtpro.utils.datetime_.prepare_dt_index`. If `parse_dates` is a list or dict, will first check whether the name of the column is among the names that are in `parse_dates`. If `to_utc` is True or `to_utc` is "index" or `to_utc` is a sequence and index name is in this sequence, will localize/convert any datetime index to the UTC timezone. 
If `to_utc` is True or `to_utc` is "columns" or `to_utc` is a sequence and column name is in this sequence, will localize/convert any datetime column to the UTC timezone.""" obj = obj.copy(deep=False) made_frame = False if isinstance(obj, pd.Series): obj = obj.to_frame() made_frame = True index_parse_dates = False if not isinstance(obj.index, pd.MultiIndex) and obj.index.dtype == object: if parse_dates is True: index_parse_dates = True elif checks.is_sequence(parse_dates) and obj.index.name in parse_dates: index_parse_dates = True if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "index") or (checks.is_sequence(to_utc) and obj.index.name in to_utc) ): index_tz_localize = "utc" index_tz_convert = "utc" index_remove_tz = remove_utc_tz else: index_tz_localize = None index_tz_convert = None index_remove_tz = False obj.index = cls.prepare_dt_index( obj.index, parse_dates=index_parse_dates, tz_localize=index_tz_localize, tz_convert=index_tz_convert, remove_tz=index_remove_tz, ) for column_name in obj.columns: column_parse_dates = False if obj[column_name].dtype == object: if parse_dates is True: column_parse_dates = True elif checks.is_sequence(parse_dates) and column_name in parse_dates: column_parse_dates = True elif not hasattr(obj[column_name], "dt"): continue if ( to_utc is True or (isinstance(to_utc, str) and to_utc.lower() == "columns") or (checks.is_sequence(to_utc) and column_name in to_utc) ): column_tz_localize = "utc" column_tz_convert = "utc" column_remove_tz = remove_utc_tz else: column_tz_localize = None column_tz_convert = None column_remove_tz = False obj[column_name] = cls.prepare_dt_column( obj[column_name], parse_dates=column_parse_dates, tz_localize=column_tz_localize, tz_convert=column_tz_convert, remove_tz=column_remove_tz, ) if made_frame: obj = obj.iloc[:, 0] return obj @classmethod def prepare_tzaware_index( cls, obj: tp.SeriesFrame, tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None, tz_convert: tp.Union[None, bool, 
tp.TimezoneLike] = None, ) -> tp.SeriesFrame: """Prepare a timezone-aware index of a Pandas object. Uses `Data.prepare_dt_index` with `parse_dates=True` and `force_tz_convert=True`. For defaults, see `vectorbtpro._settings.data`.""" obj = obj.copy(deep=False) tz_localize = cls.resolve_base_setting(tz_localize, "tz_localize") if isinstance(tz_localize, bool): if tz_localize: raise ValueError("tz_localize cannot be True") else: tz_localize = None tz_convert = cls.resolve_base_setting(tz_convert, "tz_convert") if isinstance(tz_convert, bool): if tz_convert: raise ValueError("tz_convert cannot be True") else: tz_convert = None obj.index = cls.prepare_dt_index( obj.index, parse_dates=True, tz_localize=tz_localize, tz_convert=tz_convert, force_tz_convert=True, ) return obj @classmethod def align_index( cls, data: dict, missing: tp.Optional[str] = None, silence_warnings: tp.Optional[bool] = None, ) -> dict: """Align data to have the same index. The argument `missing` accepts the following values: * 'nan': set missing data points to NaN * 'drop': remove missing data points * 'raise': raise an error For defaults, see `vectorbtpro._settings.data`.""" missing = cls.resolve_base_setting(missing, "missing_index") silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings") index = None index_changed = False for k, obj in data.items(): if index is None: index = obj.index else: if not checks.is_index_equal(index, obj.index, check_names=False): if missing == "nan": if not silence_warnings: warn("Symbols have mismatching index. Setting missing data points to NaN.") index = index.union(obj.index) index_changed = True elif missing == "drop": if not silence_warnings: warn("Symbols have mismatching index. 
@classmethod
def align_columns(
    cls,
    data: dict,
    missing: tp.Optional[str] = None,
    silence_warnings: tp.Optional[bool] = None,
) -> dict:
    """Bring all objects in a data dictionary to the same set of columns.

    Series are treated as single-column DataFrames while comparing. If no object deviates
    from the others (or there is only one object), `data` is returned as-is. Otherwise,
    every object is reindexed to the common columns and a new dictionary of the same
    type is returned. If all objects were Series, Series are returned again.

    See `Data.align_index` for `missing`."""
    if len(data) == 1:
        return data

    missing = cls.resolve_base_setting(missing, "missing_columns")
    silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings")

    common_columns = None
    any_frame = False
    any_unnamed_series = False
    changed = False
    for value in data.values():
        if isinstance(value, pd.Series):
            if value.name is None:
                any_unnamed_series = True
            value_columns = value.to_frame().columns
        else:
            any_frame = True
            value_columns = value.columns

        if common_columns is None:
            common_columns = value_columns
            continue
        if checks.is_index_equal(common_columns, value_columns, check_names=False):
            continue

        if missing == "nan":
            if not silence_warnings:
                warn("Symbols have mismatching columns. Setting missing data points to NaN.")
            common_columns = common_columns.union(value_columns)
        elif missing == "drop":
            if not silence_warnings:
                warn("Symbols have mismatching columns. Dropping missing data points.")
            common_columns = common_columns.intersection(value_columns)
        elif missing == "raise":
            raise ValueError("Symbols have mismatching columns")
        else:
            raise ValueError(f"Invalid missing: '{missing}'")
        changed = True

    if not changed:
        return data

    aligned = {}
    for key, value in data.items():
        frame = value.to_frame() if isinstance(value, pd.Series) else value
        frame = frame.reindex(columns=common_columns)
        if any_frame:
            aligned[key] = frame
            continue
        # Only Series were passed: convert back to a Series
        series = frame[common_columns[0]]
        if any_unnamed_series:
            series = series.rename(None)
        aligned[key] = series
    return type(data)(aligned)
feature_dict, symbol_dict] = None, tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None, tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None, missing_index: tp.Optional[str] = None, missing_columns: tp.Optional[str] = None, silence_warnings: tp.Optional[bool] = None, ) -> dict: """Align data. Removes any index duplicates, prepares the datetime index, and aligns the index and columns.""" if last_index is None: last_index = {} if delisted is None: delisted = {} if tz_localize is None and not isinstance(cls_or_self, type): tz_localize = cls_or_self.tz_localize if tz_convert is None and not isinstance(cls_or_self, type): tz_convert = cls_or_self.tz_convert if missing_index is None and not isinstance(cls_or_self, type): missing_index = cls_or_self.missing_index if missing_columns is None and not isinstance(cls_or_self, type): missing_columns = cls_or_self.missing_columns data = type(data)(data) for k, obj in data.items(): obj = to_pd_array(obj) obj = cls_or_self.prepare_tzaware_index(obj, tz_localize=tz_localize, tz_convert=tz_convert) if obj.index.is_monotonic_decreasing: obj = obj.iloc[::-1] elif not obj.index.is_monotonic_increasing: obj = obj.sort_index() if obj.index.has_duplicates: obj = obj[~obj.index.duplicated(keep="last")] data[k] = obj if (isinstance(data, symbol_dict) and isinstance(last_index, symbol_dict)) or ( isinstance(data, feature_dict) and isinstance(last_index, feature_dict) ): if k not in last_index: last_index[k] = obj.index[-1] if (isinstance(data, symbol_dict) and isinstance(delisted, symbol_dict)) or ( isinstance(data, feature_dict) and isinstance(delisted, feature_dict) ): if k not in delisted: delisted[k] = False data = cls_or_self.align_index(data, missing=missing_index, silence_warnings=silence_warnings) data = cls_or_self.align_columns(data, missing=missing_columns, silence_warnings=silence_warnings) first_data = data[list(data.keys())[0]] if isinstance(first_data, pd.Series): columns = [first_data.name] else: columns = 
@classmethod
def from_data(
    cls: tp.Type[DataT],
    data: tp.Union[dict, tp.SeriesFrame],
    columns_are_symbols: bool = False,
    invert_data: bool = False,
    single_key: bool = True,
    classes: tp.Optional[dict] = None,
    level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
    tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
    tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
    missing_index: tp.Optional[str] = None,
    missing_columns: tp.Optional[str] = None,
    wrapper_kwargs: tp.KwargsLike = None,
    fetch_kwargs: tp.Optional[dict] = None,
    returned_kwargs: tp.Optional[dict] = None,
    last_index: tp.Optional[dict] = None,
    delisted: tp.Optional[dict] = None,
    silence_warnings: tp.Optional[bool] = None,
    **kwargs,
) -> DataT:
    """Build a new `Data` instance out of already available data.

    A single Series/DataFrame is wrapped into a one-key dictionary first. A regular dict is
    converted to `feature_dict` if `columns_are_symbols` is True and to `symbol_dict` otherwise.
    The data is then optionally inverted, aligned with `Data.align_data`, and wrapped.

    Args:
        data (dict): Dictionary of array-like objects keyed by symbol.
        columns_are_symbols (bool): Whether columns in each DataFrame are symbols.
        invert_data (bool): Whether to invert the data dictionary with `Data.invert_data`.
        single_key (bool): See `Data.single_key`.

            Forced to False if there is more than one key.
        classes (feature_dict or symbol_dict): See `Data.classes`.
        level_name (bool, hashable or iterable of hashable): See `Data.level_name`.
        tz_localize (timezone_like): See `Data.prepare_tzaware_index`.
        tz_convert (timezone_like): See `Data.prepare_tzaware_index`.
        missing_index (str): See `Data.align_index`.
        missing_columns (str): See `Data.align_columns`.
        wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
        fetch_kwargs (feature_dict or symbol_dict): Keyword arguments initially passed to `Data.fetch_symbol`.
        returned_kwargs (feature_dict or symbol_dict): Keyword arguments returned by `Data.fetch_symbol`.
        last_index (feature_dict or symbol_dict): Last fetched index per symbol.
        delisted (feature_dict or symbol_dict): Whether symbol has been delisted.
        silence_warnings (bool): Whether to silence all warnings.
        **kwargs: Keyword arguments passed to the `__init__` method.

    For defaults, see `vectorbtpro._settings.data`."""
    wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
    classes = {} if classes is None else classes
    fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    returned_kwargs = {} if returned_kwargs is None else returned_kwargs
    last_index = {} if last_index is None else last_index
    delisted = {} if delisted is None else delisted

    if columns_are_symbols and isinstance(data, symbol_dict):
        raise TypeError("Data cannot have the type symbol_dict when columns_are_symbols=True")

    # Orientation assumed for anything that is not a key dictionary yet
    fallback_type = feature_dict if columns_are_symbols else symbol_dict
    if isinstance(data, (pd.Series, pd.DataFrame)):
        placeholder_key = "feature" if columns_are_symbols else "symbol"
        data = fallback_type(**{placeholder_key: data})
    checks.assert_instance_of(data, dict, arg_name="data")
    if not isinstance(data, key_dict):
        data = fallback_type(data)
    if invert_data:
        data = cls.invert_data(data)
    if len(data) > 1:
        single_key = False

    # Regular dicts adopt the orientation of the data
    checks.assert_instance_of(last_index, dict, arg_name="last_index")
    if not isinstance(last_index, key_dict):
        last_index = type(data)(last_index)
    checks.assert_instance_of(delisted, dict, arg_name="delisted")
    if not isinstance(delisted, key_dict):
        delisted = type(data)(delisted)

    data = cls.align_data(
        data,
        last_index=last_index,
        delisted=delisted,
        tz_localize=tz_localize,
        tz_convert=tz_convert,
        missing_index=missing_index,
        missing_columns=missing_columns,
        silence_warnings=silence_warnings,
    )
    first_key = list(data.keys())[0]
    wrapper = ArrayWrapper.from_obj(data[first_key], **wrapper_kwargs)
    attr_dicts = cls.fix_dict_types_in_kwargs(
        type(data),
        classes=classes,
        fetch_kwargs=fetch_kwargs,
        returned_kwargs=returned_kwargs,
        last_index=last_index,
        delisted=delisted,
    )
    init_kwargs = dict(
        single_key=single_key,
        level_name=level_name,
        tz_localize=tz_localize,
        tz_convert=tz_convert,
        missing_index=missing_index,
        missing_columns=missing_columns,
    )
    return cls(wrapper, data, **init_kwargs, **attr_dicts, **kwargs)
@classmethod
def has_key_dict(
    cls,
    arg: tp.Any,
    dict_type: tp.Optional[tp.Type[tp.Union[feature_dict, symbol_dict]]] = None,
) -> bool:
    """Return whether the argument is a data dictionary or a dict holding one as a direct value.

    If `dict_type` is None, checks against `key_dict`. Only the first level of values
    is inspected."""
    target_type = key_dict if dict_type is None else dict_type
    if isinstance(arg, target_type):
        return True
    if not isinstance(arg, dict):
        return False
    return any(isinstance(value, target_type) for _, value in arg.items())
@classmethod
def select_feature_from_dict(cls, feature: tp.Feature, dct: feature_dict, **kwargs) -> tp.Any:
    """Select the dictionary value belonging to a feature.

    Delegates to `Data.select_key_from_dict` with `dict_type=feature_dict`.

    Args:
        feature (hashable): Feature to look up.
        dct (feature_dict): Dictionary keyed by feature.
        **kwargs: Keyword arguments passed to `Data.select_key_from_dict`,
            such as `dct_name` and `check_dict_type`.

    Returns:
        The value stored under `feature`, that is, `dct[feature]`."""
    # Previously delegated to `select_key_kwargs`, which returns keyword arguments
    # rather than the stored value and does not accept `dct_name`
    return cls.select_key_from_dict(feature, dct, dict_type=feature_dict, **kwargs)

@classmethod
def select_symbol_from_dict(cls, symbol: tp.Symbol, dct: symbol_dict, **kwargs) -> tp.Any:
    """Select the dictionary value belonging to a symbol.

    Delegates to `Data.select_key_from_dict` with `dict_type=symbol_dict`.

    Args:
        symbol (hashable): Symbol to look up.
        dct (symbol_dict): Dictionary keyed by symbol.
        **kwargs: Keyword arguments passed to `Data.select_key_from_dict`,
            such as `dct_name` and `check_dict_type`.

    Returns:
        The value stored under `symbol`, that is, `dct[symbol]`."""
    # Previously delegated to `select_key_kwargs`, which returns keyword arguments
    # rather than the stored value and does not accept `dct_name`
    return cls.select_key_from_dict(symbol, dct, dict_type=symbol_dict, **kwargs)
def select_keys(self: DataT, keys: tp.MaybeKeys, **kwargs) -> DataT:
    """Return a new `Data` instance restricted to one or more keys of this instance.

    Keys are resolved with `Data.resolve_keys`. Selecting a single key sets `single_key`
    to True. Attributes listed in `Data._key_dict_attrs` that are keyed the same way as
    the data are filtered as well; keys missing in them are skipped, while keys missing
    in the data raise an error."""
    resolved_keys = self.resolve_keys(keys)
    single_key = not self.has_multiple_keys(resolved_keys)
    key_list = [resolved_keys] if single_key else resolved_keys

    attr_values = ((attr, getattr(self, attr)) for attr in self._key_dict_attrs)
    selected_attrs = {
        attr: self.select_from_dict(attr_value, key_list)
        for attr, attr_value in attr_values
        if isinstance(attr_value, self.dict_type)
    }
    selected_data = self.select_from_dict(self.data, key_list, raise_error=True)
    return self.replace(data=selected_data, single_key=single_key, **selected_attrs, **kwargs)
def add_feature(
    self: DataT,
    feature: tp.Feature,
    data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
    pull_feature: bool = False,
    pull_kwargs: tp.KwargsLike = None,
    reuse_fetch_kwargs: bool = True,
    run_kwargs: tp.KwargsLike = None,
    wrap_kwargs: tp.KwargsLike = None,
    merge_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> DataT:
    """Create a new `Data` instance with a new feature added to this instance.

    If `data` is None, uses `Data.run`. If in addition `pull_feature` is True,
    uses `Data.pull` instead.

    Args:
        feature (hashable): Feature to add.
        data (Series, DataFrame, or CustomTemplate): Data of the feature.

            A template is substituted with this instance available as `data`.
        pull_feature (bool): Whether to pull the feature if `data` is None.
        pull_kwargs (dict): Keyword arguments passed to `Data.pull`.
        reuse_fetch_kwargs (bool): Whether to merge the fetching arguments shared by all
            existing features (see `Data.get_intersection_dict`) into `pull_kwargs`.

            Applies only if `Data.fetch_kwargs` is a `feature_dict`.
        run_kwargs (dict): Keyword arguments passed to `Data.run`.
        wrap_kwargs (dict): Keyword arguments passed to the symbol wrapper's `wrap`
            when wrapping the output of `Data.run`.
        merge_kwargs (dict): Keyword arguments passed to `Data.merge`.
        **kwargs: Keyword arguments passed to `Data.from_data`.

            Arguments named in `Data._key_dict_attrs` must not be key dictionaries;
            they are wrapped into a `feature_dict` under `feature`."""
    # Default every optional dict, including `pull_kwargs`: unpacking None with ** raises TypeError
    if pull_kwargs is None:
        pull_kwargs = {}
    if run_kwargs is None:
        run_kwargs = {}
    if wrap_kwargs is None:
        wrap_kwargs = {}
    if merge_kwargs is None:
        merge_kwargs = {}

    if data is None:
        if pull_feature:
            # Skip empty fetch_kwargs: there is nothing to intersect
            if reuse_fetch_kwargs and isinstance(self.fetch_kwargs, feature_dict) and len(self.fetch_kwargs) > 0:
                pull_kwargs = merge_dicts(self.get_intersection_dict(self.fetch_kwargs), pull_kwargs)
            data = type(self).pull(features=feature, **pull_kwargs).get(feature=feature)
        else:
            data = self.run(feature, **run_kwargs, unpack=True)
            data = self.symbol_wrapper.wrap(data, **wrap_kwargs)
    if isinstance(data, CustomTemplate):
        data = data.substitute(dict(data=self), eval_id="data")
    if isinstance(data, pd.Series) and self.symbol_wrapper.ndim == 1:
        # With one symbol only, name the Series after that symbol
        data = data.copy(deep=False)
        data.name = self.symbols[0]

    for attr in self._key_dict_attrs:
        if attr in kwargs:
            checks.assert_not_instance_of(kwargs[attr], key_dict, arg_name=attr)
            kwargs[attr] = feature_dict({feature: kwargs[attr]})
    data = type(self).from_data(
        feature_dict({feature: data}),
        invert_data=not self.feature_oriented,
        **kwargs,
    )

    # Explicitly passed non-key-dict options must not conflict with those of this instance
    on_merge_conflict = {k: "error" for k in kwargs if k not in self._key_dict_attrs}
    on_merge_conflict["_def"] = "first"
    return self.merge(data, on_merge_conflict=on_merge_conflict, **merge_kwargs)
def add_key(
    self: DataT,
    key: tp.Key,
    data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
    **kwargs,
) -> DataT:
    """Return a new `Data` instance with a new key added to this instance.

    Dispatches to `Data.add_feature` if this instance is feature-oriented
    and to `Data.add_symbol` otherwise."""
    adder = self.add_feature if self.feature_oriented else self.add_symbol
    return adder(key, data=data, **kwargs)
@classmethod
def rename_in_dict(cls, dct: dict, rename: tp.Dict[tp.Key, tp.Key]) -> dict:
    """Return a copy of a dict with its keys renamed.

    Keys not present in `rename` are kept unchanged. The returned dict is of the same type as `dct`."""
    renamed = {}
    for old_key, value in dct.items():
        renamed[rename.get(old_key, old_key)] = value
    return type(dct)(renamed)
def rename_features(
    self: DataT,
    rename: tp.Union[tp.MaybeFeatures, tp.Dict[tp.Feature, tp.Feature]],
    to: tp.Optional[tp.MaybeFeatures] = None,
    **kwargs,
) -> DataT:
    """Return a new `Data` instance with features renamed.

    Dispatches to `Data.rename_keys` if this instance is feature-oriented
    and to `Data.rename_columns` otherwise."""
    renamer = self.rename_keys if self.feature_oriented else self.rename_columns
    return renamer(rename, to=to, **kwargs)
def remove_symbols(self: DataT, symbols: tp.MaybeSymbols, **kwargs) -> DataT:
    """Create a new `Data` instance with one or more symbols removed from this instance.

    Args:
        symbols (hashable or sequence of hashable): One or multiple symbols to remove.

            Each symbol is looked up with `Data.get_symbol_idx` using `raise_error=True`.
        **kwargs: Keyword arguments passed to `Data.select_symbol_idxs`.

    Raises:
        ValueError: If no symbols would be left after this operation."""
    if self.has_multiple_keys(symbols):
        remove_symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols]
    else:
        remove_symbol_idxs = [self.get_symbol_idx(symbols, raise_error=True)]
    # Keep the remaining positions in their original order
    keep_symbol_idxs = [i for i in range(len(self.symbols)) if i not in remove_symbol_idxs]
    if len(keep_symbol_idxs) == 0:
        raise ValueError("No symbols will be left after this operation")
    return self.select_symbol_idxs(keep_symbol_idxs, **kwargs)
@hybrid_method
def merge(
    cls_or_self: tp.MaybeType[DataT],
    *datas: DataT,
    rename: tp.Optional[tp.Dict[tp.Key, tp.Key]] = None,
    **kwargs,
) -> DataT:
    """Merge multiple `Data` instances.

    Can merge both symbols and features. Data is overridden in the order as provided in `datas`:
    if a key appears in more than one instance, the objects are combined with
    `combine_first` such that values of a later instance take precedence. Columns of
    the earlier object come first, followed by any new columns of the later one.

    Args:
        *datas (Data): Instances to merge.

            Can also be passed as a single sequence. If called on an instance,
            that instance is merged first.
        rename (dict): Mapping from existing keys to new keys.

            If provided, must contain every key of every instance.
        **kwargs: Keyword arguments passed to `Data.resolve_merge_kwargs` and
            eventually to `Data.from_data`.

            `missing_index` and `missing_columns` default to "nan".

    Raises:
        TypeError: If the instances use different dict types for data or
            for any attribute in `Data._key_dict_attrs`."""
    if len(datas) == 1 and not isinstance(datas[0], Data):
        datas = datas[0]
    datas = list(datas)
    if not isinstance(cls_or_self, type):
        datas = (cls_or_self, *datas)

    data_type = None
    data = {}
    single_key = True
    attr_dicts = dict()

    for instance in datas:
        if data_type is None:
            data_type = type(instance.data)
        elif not isinstance(instance.data, data_type):
            raise TypeError("Objects to be merged must have the same dict type for data")
        if not instance.single_key:
            single_key = False

        for k in instance.keys:
            if rename is None:
                new_k = k
            else:
                new_k = rename[k]

            if new_k in data:
                # Key already seen: the current (later) object overrides the stored one
                obj1 = instance.data[k]
                obj2 = data[new_k]
                both_were_series = True
                if isinstance(obj1, pd.Series):
                    obj1 = obj1.to_frame()
                else:
                    both_were_series = False
                if isinstance(obj2, pd.Series):
                    obj2 = obj2.to_frame()
                else:
                    both_were_series = False
                new_obj = obj1.combine_first(obj2)
                # Restore column order: stored columns first, then new ones
                new_columns = []
                for c in obj2.columns:
                    new_columns.append(c)
                for c in obj1.columns:
                    if c not in new_columns:
                        new_columns.append(c)
                new_obj = new_obj[new_columns]
                if new_obj.shape[1] == 1 and both_were_series:
                    new_obj = new_obj.iloc[:, 0]
                data[new_k] = new_obj
            else:
                data[new_k] = instance.data[k]

            # Attributes keyed the same way as the data are copied per key
            for attr in cls_or_self._key_dict_attrs:
                attr_value = getattr(instance, attr)
                if (issubclass(data_type, symbol_dict) and isinstance(attr_value, symbol_dict)) or (
                    issubclass(data_type, feature_dict) and isinstance(attr_value, feature_dict)
                ):
                    if k in attr_value:
                        if attr not in attr_dicts:
                            attr_dicts[attr] = type(attr_value)()
                        elif not isinstance(attr_value, type(attr_dicts[attr])):
                            raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
                        attr_dicts[attr][new_k] = attr_value[k]

        # Attributes keyed by the opposite orientation (i.e., by column) are merged as a whole
        for attr in cls_or_self._key_dict_attrs:
            attr_value = getattr(instance, attr)
            if (issubclass(data_type, symbol_dict) and isinstance(attr_value, feature_dict)) or (
                issubclass(data_type, feature_dict) and isinstance(attr_value, symbol_dict)
            ):
                if attr not in attr_dicts:
                    attr_dicts[attr] = type(attr_value)()
                elif not isinstance(attr_value, type(attr_dicts[attr])):
                    raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
                # Pass the mapping positionally: ** unpacking fails for non-string keys
                attr_dicts[attr].update(attr_value)

    if "missing_index" not in kwargs:
        kwargs["missing_index"] = "nan"
    if "missing_columns" not in kwargs:
        kwargs["missing_columns"] = "nan"
    kwargs = cls_or_self.resolve_merge_kwargs(
        *[instance.config for instance in datas],
        wrapper=None,
        data=data_type(data),
        single_key=single_key,
        **attr_dicts,
        **kwargs,
    )
    # The wrapper is rebuilt by `Data.from_data`
    kwargs.pop("wrapper", None)
    return cls_or_self.from_data(**kwargs)
@classmethod
def resolve_keys_meta(
    cls,
    keys: tp.Union[None, dict, tp.MaybeKeys] = None,
    keys_are_features: tp.Optional[bool] = None,
    features: tp.Union[None, dict, tp.MaybeFeatures] = None,
    symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
) -> tp.Kwargs:
    """Resolve metadata for keys.

    At most one of `keys`, `features`, and `symbols` may be provided.

    Returns:
        A dict with the entries `keys`, `keys_are_features`, and `dict_type`
        (either `feature_dict` or `symbol_dict`).

    Raises:
        ValueError: If more than one of `keys`, `features`, `symbols` is given.
        TypeError: If the dict type of `keys` contradicts `keys_are_features`."""
    if keys is not None and features is not None:
        raise ValueError("Must provide either keys or features, not both")
    if keys is not None and symbols is not None:
        raise ValueError("Must provide either keys or symbols, not both")
    if features is not None and symbols is not None:
        raise ValueError("Must provide either features or symbols, not both")

    if keys is not None:
        # Generic keys: their dict type (if any) pins the orientation,
        # otherwise the base setting decides.
        if isinstance(keys, feature_dict):
            if keys_are_features is not None and not keys_are_features:
                raise TypeError("Keys are of type feature_dict but keys_are_features is False")
            keys_are_features = True
        elif isinstance(keys, symbol_dict):
            if keys_are_features is not None and keys_are_features:
                raise TypeError("Keys are of type symbol_dict but keys_are_features is True")
            keys_are_features = False
        keys_are_features = cls.resolve_base_setting(keys_are_features, "keys_are_features")
        dict_type = feature_dict if keys_are_features else symbol_dict
    elif features is not None:
        if isinstance(features, dict):
            cls.check_dict_type(features, "features", dict_type=feature_dict)
        keys, keys_are_features, dict_type = features, True, feature_dict
    else:
        # Either symbols were given, or nothing was given at all; in the latter
        # case keys stay None and the symbol orientation is used.
        if isinstance(symbols, dict):
            cls.check_dict_type(symbols, "symbols", dict_type=symbol_dict)
        keys, keys_are_features, dict_type = symbols, False, symbol_dict

    return dict(
        keys=keys,
        keys_are_features=keys_are_features,
        dict_type=dict_type,
    )
note Tuple is considered as a single feature (tuple is a hashable). symbols (hashable, sequence of hashable, or dict): One or multiple symbols. If provided as a dictionary, will use keys as symbols and values as keyword arguments. !!! note Tuple is considered as a single symbol (tuple is a hashable). classes (feature_dict or symbol_dict): See `Data.classes`. Can be a hashable (single value), a dictionary (class names as keys and class values as values), or a sequence of such. !!! note Tuple is considered as a single class (tuple is a hashable). level_name (bool, hashable or iterable of hashable): See `Data.level_name`. tz_localize (any): See `Data.from_data`. tz_convert (any): See `Data.from_data`. missing_index (str): See `Data.from_data`. missing_columns (str): See `Data.from_data`. wrapper_kwargs (dict): See `Data.from_data`. skip_on_error (bool): Whether to skip the feature/symbol when an exception is raised. silence_warnings (bool): Whether to silence all warnings. Will also forward this argument to `Data.fetch_feature`/`Data.fetch_symbol` if in the signature. execute_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.execution.execute`. return_raw (bool): Whether to return the raw outputs. **kwargs: Passed to `Data.fetch_feature`/`Data.fetch_symbol`. If two features/symbols require different keyword arguments, pass `key_dict` or `feature_dict`/`symbol_dict` for each argument. For defaults, see `vectorbtpro._settings.data`. 
""" keys_meta = cls.resolve_keys_meta( keys=keys, keys_are_features=keys_are_features, features=features, symbols=symbols, ) keys = keys_meta["keys"] keys_are_features = keys_meta["keys_are_features"] dict_type = keys_meta["dict_type"] fetch_kwargs = dict_type() if isinstance(keys, dict): new_keys = [] for k, key_fetch_kwargs in keys.items(): new_keys.append(k) fetch_kwargs[k] = key_fetch_kwargs keys = new_keys single_key = False elif cls.has_multiple_keys(keys): keys = list(keys) if len(set(keys)) < len(keys): raise ValueError("Duplicate keys provided") single_key = False else: single_key = True keys = [keys] if classes is not None: cls.check_dict_type(classes, arg_name="classes", dict_type=dict_type) if not isinstance(classes, key_dict): new_classes = {} single_class = checks.is_hashable(classes) or isinstance(classes, dict) if single_class: for k in keys: if isinstance(classes, dict): new_classes[k] = classes else: if keys_are_features: new_classes[k] = {"feature_class": classes} else: new_classes[k] = {"symbol_class": classes} else: for i, k in enumerate(keys): _classes = classes[i] if not isinstance(_classes, dict): if keys_are_features: _classes = {"feature_class": _classes} else: _classes = {"symbol_class": _classes} new_classes[k] = _classes classes = new_classes wrapper_kwargs = cls.resolve_base_setting(wrapper_kwargs, "wrapper_kwargs", merge=True) skip_on_error = cls.resolve_base_setting(skip_on_error, "skip_on_error") silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings") execute_kwargs = cls.resolve_base_setting(execute_kwargs, "execute_kwargs", merge=True) execute_kwargs = merge_dicts(dict(show_progress=not single_key), execute_kwargs) tasks = [] if keys_are_features: func_arg_names = get_func_arg_names(cls.fetch_feature) else: func_arg_names = get_func_arg_names(cls.fetch_symbol) for k in keys: if keys_are_features: key_fetch_func = cls.try_fetch_feature key_fetch_kwargs = cls.select_feature_kwargs(k, kwargs) else: 
key_fetch_func = cls.try_fetch_symbol key_fetch_kwargs = cls.select_symbol_kwargs(k, kwargs) if "silence_warnings" in func_arg_names: key_fetch_kwargs["silence_warnings"] = silence_warnings if k in fetch_kwargs: key_fetch_kwargs = merge_dicts(key_fetch_kwargs, fetch_kwargs[k]) tasks.append( Task( key_fetch_func, k, skip_on_error=skip_on_error, silence_warnings=silence_warnings, fetch_kwargs=key_fetch_kwargs, ) ) fetch_kwargs[k] = key_fetch_kwargs key_index = cls.get_key_index(keys=keys, level_name=level_name, feature_oriented=keys_are_features) outputs = execute(tasks, size=len(keys), keys=key_index, **execute_kwargs) if return_raw: return outputs data = dict_type() returned_kwargs = dict_type() common_tz_localize = None common_tz_convert = None common_freq = None for i, out in enumerate(outputs): k = keys[i] if out is not None: if isinstance(out, tuple): _data = out[0] _returned_kwargs = out[1] else: _data = out _returned_kwargs = {} _data = to_any_array(_data) _tz = _returned_kwargs.pop("tz", None) _tz_localize = _returned_kwargs.pop("tz_localize", None) _tz_convert = _returned_kwargs.pop("tz_convert", None) _freq = _returned_kwargs.pop("freq", None) if _tz is not None: if _tz_localize is None: _tz_localize = _tz if _tz_convert is None: _tz_convert = _tz if _tz_localize is not None: if common_tz_localize is None: common_tz_localize = _tz_localize elif common_tz_localize != _tz_localize: raise ValueError("Returned objects have different timezones (tz_localize)") if _tz_convert is not None: if common_tz_convert is None: common_tz_convert = _tz_convert elif common_tz_convert != _tz_convert: if not silence_warnings: warn(f"Returned objects have different timezones (tz_convert). 
@classmethod
def from_data_str(cls: tp.Type[DataT], data_str: str) -> DataT:
    """Parse a `Data` instance from a string.

    For example: `YFData:BTC-USD` or just `BTC-USD` where the data class is
    `vectorbtpro.data.custom.yf.YFData` by default.

    Only the first colon separates the class name from the symbol, so symbols
    that themselves contain colons (such as `SomeData:exchange:PAIR`) are
    passed through intact instead of failing to unpack.

    Args:
        data_str (str): Either `<class name>:<symbol>` or just `<symbol>`.

    Returns:
        The result of calling `pull` on the resolved data class.

    Raises:
        AttributeError: If the class name is not an attribute of
            `vectorbtpro.data.custom`."""
    from vectorbtpro.data import custom

    if ":" in data_str:
        # maxsplit=1: everything after the first colon belongs to the symbol
        cls_name, symbol = data_str.split(":", 1)
        cls_name = cls_name.strip()
        symbol = symbol.strip()
        return getattr(custom, cls_name).pull(symbol)
    return custom.YFData.pull(data_str.strip())
def update(
    self: DataT,
    *,
    concat: bool = True,
    skip_on_error: tp.Optional[bool] = None,
    silence_warnings: tp.Optional[bool] = None,
    execute_kwargs: tp.KwargsLike = None,
    return_raw: bool = False,
    **kwargs,
) -> tp.Union[DataT, tp.List[tp.Any]]:
    """Update data.

    Fetches new data for each feature/symbol using `Data.update_feature`/`Data.update_symbol`.

    Args:
        concat (bool): Whether to concatenate existing and updated/new data.
        skip_on_error (bool): Whether to skip the feature/symbol when an exception is raised.
        silence_warnings (bool): Whether to silence all warnings.

            Will also forward this argument to `Data.update_feature`/`Data.update_symbol`
            if accepted by `Data.fetch_feature`/`Data.fetch_symbol`.
        execute_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.execution.execute`.
        return_raw (bool): Whether to return the raw outputs.
        **kwargs: Passed to `Data.update_feature`/`Data.update_symbol`.

            If two features/symbols require different keyword arguments,
            pass `key_dict` or `feature_dict`/`symbol_dict` for each argument.

    Returns:
        The raw outputs of the executed tasks if `return_raw` is True, a copy of this
        instance if nothing was updated, and otherwise a new instance created with
        `replace` that carries the new data, wrapper index, returned kwargs and last index.

    !!! note
        Returns a new `Data` instance instead of changing the data in place.
    """
    skip_on_error = self.resolve_base_setting(skip_on_error, "skip_on_error")
    silence_warnings = self.resolve_base_setting(silence_warnings, "silence_warnings")
    execute_kwargs = self.resolve_base_setting(execute_kwargs, "execute_kwargs", merge=True)
    execute_kwargs = merge_dicts(dict(show_progress=False), execute_kwargs)
    # The fetch signature (not the update signature) decides which optional
    # arguments get forwarded.
    if self.feature_oriented:
        func_arg_names = get_func_arg_names(self.fetch_feature)
    else:
        func_arg_names = get_func_arg_names(self.fetch_symbol)
    if "show_progress" in func_arg_names and "show_progress" not in kwargs:
        kwargs["show_progress"] = False
    checks.assert_instance_of(self.last_index, self.dict_type, "last_index")
    checks.assert_instance_of(self.delisted, self.dict_type, "delisted")

    # Build one task per key that is not delisted
    tasks = []
    # NOTE(review): key_indices is collected but never read afterwards, while
    # execute() receives the size and keys of *all* keys - confirm whether the
    # subset selected here was meant to be passed instead.
    key_indices = []
    for i, k in enumerate(self.keys):
        if not self.delisted.get(k, False):
            if self.feature_oriented:
                key_update_func = self.try_update_feature
                key_update_kwargs = self.select_feature_kwargs(k, kwargs)
            else:
                key_update_func = self.try_update_symbol
                key_update_kwargs = self.select_symbol_kwargs(k, kwargs)
            if "silence_warnings" in func_arg_names:
                key_update_kwargs["silence_warnings"] = silence_warnings
            tasks.append(
                Task(
                    key_update_func,
                    k,
                    skip_on_error=skip_on_error,
                    silence_warnings=silence_warnings,
                    update_kwargs=key_update_kwargs,
                )
            )
            key_indices.append(i)
    outputs = execute(tasks, size=len(self.keys), keys=self.key_index, **execute_kwargs)
    if return_raw:
        return outputs

    # Normalize each output; keys without usable output get an empty slice
    new_data = {}
    new_last_index = {}
    new_returned_kwargs = {}
    i = 0
    for k, obj in self.data.items():
        if self.delisted.get(k, False):
            out = None
        else:
            # Outputs exist only for non-delisted keys, hence the separate counter
            out = outputs[i]
            i += 1
        skip_key = False
        if out is not None:
            if isinstance(out, tuple):
                new_obj = out[0]
                new_returned_kwargs[k] = out[1]
            else:
                new_obj = out
                new_returned_kwargs[k] = {}
            new_obj = to_any_array(new_obj)
            if new_obj.size == 0:
                if not silence_warnings:
                    if self.feature_oriented:
                        warn(f"Feature '{str(k)}' returned an empty array. Skipping.")
                    else:
                        warn(f"Symbol '{str(k)}' returned an empty array. Skipping.")
                skip_key = True
            else:
                if not isinstance(new_obj, (pd.Series, pd.DataFrame)):
                    new_obj = to_pd_array(new_obj)
                    # Non-Pandas output: continue numbering from the last old index
                    new_obj.index = pd.RangeIndex(
                        start=obj.index[-1],
                        stop=obj.index[-1] + new_obj.shape[0],
                        step=1,
                    )
                new_obj = self.prepare_tzaware_index(
                    new_obj,
                    tz_localize=self.tz_localize,
                    tz_convert=self.tz_convert,
                )
                if new_obj.index.is_monotonic_decreasing:
                    new_obj = new_obj.iloc[::-1]
                elif not new_obj.index.is_monotonic_increasing:
                    new_obj = new_obj.sort_index()
                if new_obj.index.has_duplicates:
                    new_obj = new_obj[~new_obj.index.duplicated(keep="last")]
                new_data[k] = new_obj
                if len(new_obj.index) > 0:
                    new_last_index[k] = new_obj.index[-1]
                else:
                    new_last_index[k] = self.last_index[k]
        else:
            skip_key = True
        if skip_key:
            new_data[k] = obj.iloc[0:0]
            new_last_index[k] = self.last_index[k]

    # Get the last index in the old data from where the new data should begin
    from_index = None
    for k, new_obj in new_data.items():
        if len(new_obj.index) > 0:
            index = new_obj.index[0]
        else:
            continue
        if from_index is None or index < from_index:
            from_index = index
    if from_index is None:
        if not silence_warnings:
            if self.feature_oriented:
                warn("None of the features were updated")
            else:
                warn("None of the symbols were updated")
        return self.copy()

    # Concatenate the updated old data and the new data
    for k, new_obj in new_data.items():
        if len(new_obj.index) > 0:
            to_index = new_obj.index[0]
        else:
            to_index = None
        obj = self.data[k]
        if isinstance(obj, pd.DataFrame) and isinstance(new_obj, pd.DataFrame):
            shared_columns = obj.columns.intersection(new_obj.columns)
            obj = obj[shared_columns]
            new_obj = new_obj[shared_columns]
        elif isinstance(new_obj, pd.DataFrame):
            if obj.name is not None:
                new_obj = new_obj[obj.name]
            else:
                new_obj = new_obj[0]
        elif isinstance(obj, pd.DataFrame):
            if new_obj.name is not None:
                obj = obj[new_obj.name]
            else:
                obj = obj[0]
        obj = obj.loc[from_index:to_index]
        new_obj = pd.concat((obj, new_obj), axis=0)
        if new_obj.index.has_duplicates:
            new_obj = new_obj[~new_obj.index.duplicated(keep="last")]
        new_data[k] = new_obj

    # Align the index and columns in the new data
    new_data = self.align_index(new_data, missing=self.missing_index, silence_warnings=silence_warnings)
    new_data = self.align_columns(new_data, missing=self.missing_columns, silence_warnings=silence_warnings)

    # Align the columns and data type in the old and new data
    for k, new_obj in new_data.items():
        obj = self.data[k]
        if isinstance(obj, pd.DataFrame) and isinstance(new_obj, pd.DataFrame):
            new_obj = new_obj[obj.columns]
        elif isinstance(new_obj, pd.DataFrame):
            if obj.name is not None:
                new_obj = new_obj[obj.name]
            else:
                new_obj = new_obj[0]
        if isinstance(obj, pd.DataFrame):
            new_obj = new_obj.astype(obj.dtypes, errors="ignore")
        else:
            new_obj = new_obj.astype(obj.dtype, errors="ignore")
        new_data[k] = new_obj

    if not concat:
        # Do not concatenate with the old data
        for k, new_obj in new_data.items():
            if isinstance(new_obj.index, pd.DatetimeIndex):
                new_obj.index.freq = new_obj.index.inferred_freq
        new_index = new_data[self.keys[0]].index
        return self.replace(
            wrapper=self.wrapper.replace(index=new_index),
            data=self.dict_type(new_data),
            returned_kwargs=self.dict_type(new_returned_kwargs),
            last_index=self.dict_type(new_last_index),
        )

    # Append the new data to the old data
    for k, new_obj in new_data.items():
        obj = self.data[k]
        obj = obj.loc[:from_index]
        # The label slice is empty when the new data starts before all of the old
        # data (or the old object is empty); indexing [-1] would then raise an
        # IndexError, so only trim the overlapping row when there is one.
        if len(obj.index) > 0 and obj.index[-1] == from_index:
            obj = obj.iloc[:-1]
        new_obj = pd.concat((obj, new_obj), axis=0)
        if isinstance(new_obj.index, pd.DatetimeIndex):
            new_obj.index.freq = new_obj.index.inferred_freq
        new_data[k] = new_obj

    new_index = new_data[self.keys[0]].index
    return self.replace(
        wrapper=self.wrapper.replace(index=new_index),
        data=self.dict_type(new_data),
        returned_kwargs=self.dict_type(new_returned_kwargs),
        last_index=self.dict_type(new_last_index),
    )
bool = False, key_wrapper_kwargs: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> DataT: """Transform data. If one key (i.e., feature or symbol), passes the entire Series/DataFrame. If `per_feature` is True, passes the Series/DataFrame of each feature. If `per_symbol` is True, passes the Series/DataFrame of each symbol. If both are True, passes each feature and symbol combination as a Series if `pass_frame` is False or as a DataFrame with one column if `pass_frame` is True. If both are False, concatenates all features and symbols into a single DataFrame and calls `transform_func` on it. Then, splits the data by key and builds a new `Data` instance. Keyword arguments `key_wrapper_kwargs` are passed to `Data.get_key_wrapper` to control, for example, attachment of classes. After the transformation, the new data is aligned using `Data.align_data`. If the data is not a Pandas object, it's broadcasted to the original data with `broadcast_kwargs`. !!! note The returned object must have the same type and dimensionality as the input object. Number of columns (i.e., features and symbols) and their names must stay the same. To remove columns, use either indexing or `Data.select` (depending on the data orientation). To add new columns, use either column stacking or `Data.merge`. 
Index, on the other hand, can be changed freely.""" if key_wrapper_kwargs is None: key_wrapper_kwargs = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} def _transform(data, _template_context=None): _transform_func = substitute_templates(transform_func, _template_context, eval_id="transform_func") _args = substitute_templates(args, _template_context, eval_id="args") _kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs") out = _transform_func(data, *_args, **_kwargs) if not isinstance(out, (pd.Series, pd.DataFrame)): out = broadcast_to(out, data, **broadcast_kwargs) return out if (self.feature_oriented and (per_feature and not per_symbol)) or ( self.symbol_oriented and (per_symbol and not per_feature) ): new_data = self.dict_type() for k in self.keys: if self.feature_oriented: _template_context = merge_dicts(dict(key=k, feature=k), template_context) else: _template_context = merge_dicts(dict(key=k, symbol=k), template_context) new_data[k] = _transform(self.data[k], _template_context) checks.assert_meta_equal(new_data[k], self.data[k], axis=1) elif (self.feature_oriented and (per_symbol and not per_feature)) or ( self.symbol_oriented and (per_feature and not per_symbol) ): first_data = self.data[list(self.data.keys())[0]] if isinstance(first_data, pd.Series): concat_data = pd.concat(self.data.values(), axis=1) key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs) concat_data.columns = key_wrapper.columns if self.feature_oriented: _template_context = merge_dicts( dict(column=self.wrapper.columns[0], symbol=self.wrapper.columns[0]), template_context, ) else: _template_context = merge_dicts( dict(column=self.wrapper.columns[0], feature=self.wrapper.columns[0]), template_context, ) new_concat_data = _transform(concat_data, _template_context) checks.assert_meta_equal(new_concat_data, concat_data, axis=1) new_data = self.dict_type() for i, k in enumerate(self.keys): new_data[k] = 
new_concat_data.iloc[:, i] new_data[k].name = first_data.name else: all_concat_data = [] for i in range(len(first_data.columns)): concat_data = pd.concat([self.data[k].iloc[:, [i]] for k in self.keys], axis=1) key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs) concat_data.columns = key_wrapper.columns if self.feature_oriented: _template_context = merge_dicts( dict(column=self.wrapper.columns[i], symbol=self.wrapper.columns[i]), template_context, ) else: _template_context = merge_dicts( dict(column=self.wrapper.columns[i], feature=self.wrapper.columns[i]), template_context, ) new_concat_data = _transform(concat_data, _template_context) checks.assert_meta_equal(new_concat_data, concat_data, axis=1) all_concat_data.append(new_concat_data) new_data = self.dict_type() for i, k in enumerate(self.keys): new_objs = [] for c in range(len(first_data.columns)): new_objs.append(all_concat_data[c].iloc[:, [i]]) new_data[k] = pd.concat(new_objs, axis=1) new_data[k].columns = first_data.columns else: key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs) concat_data = pd.concat(self.data.values(), axis=1, keys=key_wrapper.columns) if (self.feature_oriented and (per_feature and per_symbol)) or ( self.symbol_oriented and (per_symbol and per_feature) ): new_concat_data = [] for i in range(len(concat_data.columns)): if self.feature_oriented: _template_context = merge_dicts( dict( key=self.keys[i // len(self.wrapper.columns)], column=self.wrapper.columns[i % len(self.wrapper.columns)], feature=self.keys[i // len(self.wrapper.columns)], symbol=self.wrapper.columns[i % len(self.wrapper.columns)], ), template_context, ) else: _template_context = merge_dicts( dict( key=self.wrapper.columns[i % len(self.wrapper.columns)], column=self.keys[i // len(self.wrapper.columns)], feature=self.wrapper.columns[i % len(self.wrapper.columns)], symbol=self.keys[i // len(self.wrapper.columns)], ), template_context, ) if pass_frame: new_obj = _transform(concat_data.iloc[:, [i]], _template_context) 
checks.assert_meta_equal(new_obj, concat_data.iloc[:, [i]], axis=1) else: new_obj = _transform(concat_data.iloc[:, i], _template_context) checks.assert_meta_equal(new_obj, concat_data.iloc[:, i], axis=1) new_concat_data.append(new_obj) new_concat_data = pd.concat(new_concat_data, axis=1) else: new_concat_data = _transform(concat_data) checks.assert_meta_equal(new_concat_data, concat_data, axis=1) native_concat_data = pd.concat(self.data.values(), axis=1, keys=None) new_concat_data.columns = native_concat_data.columns new_data = self.dict_type() first_data = self.data[list(self.data.keys())[0]] for i, k in enumerate(self.keys): if isinstance(first_data, pd.Series): new_data[k] = new_concat_data.iloc[:, i] new_data[k].name = first_data.name else: start_i = first_data.shape[1] * i stop_i = first_data.shape[1] * (1 + i) new_data[k] = new_concat_data.iloc[:, start_i:stop_i] new_data[k].columns = first_data.columns new_data = self.align_data(new_data) first_data = new_data[list(new_data.keys())[0]] new_wrapper = self.wrapper.replace(index=first_data.index) return self.replace( wrapper=new_wrapper, data=new_data, ) def dropna(self: DataT, **kwargs) -> DataT: """Drop missing values. 
def mirror_ohlc(
    self: DataT,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    start_value: tp.ArrayLike = np.nan,
    ref_feature: tp.ArrayLike = -1,
) -> DataT:
    """Mirror OHLC features.

    Resolves the features "Open", "High", "Low", and "Close", runs `mirror_ohlc_nb`
    (resolved through the jitting and chunking registries) on those that exist,
    and swaps the results in with `Data.transform` on a per-feature basis.
    Any other feature is passed through unchanged.

    Args:
        jitted: Option passed to `jit_reg.resolve_option`.
        chunked: Option passed to `ch_reg.resolve_option`.
        start_value (array_like): Converted with `to_1d_array` and passed to the function.
        ref_feature (array_like): If a string, mapped with `map_enum_fields` against
            `PriceFeature`; then converted with `to_1d_array` and passed to the function."""
    if isinstance(ref_feature, str):
        ref_feature = map_enum_fields(ref_feature, PriceFeature)

    open_name, high_name, low_name, close_name = [
        self.resolve_feature(label) for label in ("Open", "High", "Low", "Close")
    ]

    mirror_func = jit_reg.resolve_option(mirror_ohlc_nb, jitted)
    mirror_func = ch_reg.resolve_option(mirror_func, chunked)

    def _feature_as_2d(name):
        # A feature that could not be resolved is passed on as None
        if name is None:
            return None
        return to_2d_array(self.get_feature(name))

    new_open, new_high, new_low, new_close = mirror_func(
        self.symbol_wrapper.shape_2d,
        open=_feature_as_2d(open_name),
        high=_feature_as_2d(high_name),
        low=_feature_as_2d(low_name),
        close=_feature_as_2d(close_name),
        start_value=to_1d_array(start_value),
        ref_feature=to_1d_array(ref_feature),
    )
    replacements = (
        (open_name, new_open),
        (high_name, new_high),
        (low_name, new_low),
        (close_name, new_close),
    )

    def _mirror_ohlc(df, feature, **_kwargs):
        # Return the mirrored array of the first price feature that matches
        for name, mirrored in replacements:
            if name is not None and feature == name:
                return mirrored
        return df

    return self.transform(_mirror_ohlc, Rep("feature"), per_feature=True)
The function must accept the `Data` instance, object, and resampler.""" if wrapper_meta is None: wrapper_meta = self.wrapper.resample_meta(*args, **kwargs) def _resample_feature(obj, feature, symbol=None): resample_func = self.feature_config.get(feature, {}).get("resample_func", None) if resample_func is not None: if isinstance(resample_func, str): return obj.vbt.resample_apply(wrapper_meta["resampler"], resample_func) return resample_func(self, obj, wrapper_meta["resampler"]) if isinstance(feature, str) and feature.lower() == "open": return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.first_reduce_nb) if isinstance(feature, str) and feature.lower() == "high": return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.max_reduce_nb) if isinstance(feature, str) and feature.lower() == "low": return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.min_reduce_nb) if isinstance(feature, str) and feature.lower() == "close": return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.last_reduce_nb) if isinstance(feature, str) and feature.lower() == "volume": return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.sum_reduce_nb) if isinstance(feature, str) and feature.lower() == "trade count": return obj.vbt.resample_apply( wrapper_meta["resampler"], generic_nb.sum_reduce_nb, wrap_kwargs=dict(dtype=int), ) if isinstance(feature, str) and feature.lower() == "vwap": volume_obj = None for feature2 in self.features: if isinstance(feature2, str) and feature2.lower() == "volume": if self.feature_oriented: volume_obj = self.data[feature2] else: volume_obj = self.data[symbol][feature2] if volume_obj is None: raise ValueError("Volume is required to resample VWAP") return pd.DataFrame.vbt.resample_apply( wrapper_meta["resampler"], generic_nb.wmean_range_reduce_meta_nb, to_2d_array(obj), to_2d_array(volume_obj), wrapper=self.wrapper[feature], ) raise ValueError(f"Cannot resample feature '{feature}'. 
Specify resample_func in feature_config.") new_data = self.dict_type() if self.feature_oriented: for feature in self.features: new_data[feature] = _resample_feature(self.data[feature], feature) else: for symbol, obj in self.data.items(): _new_obj = [] for feature in self.features: if self.single_feature: _new_obj.append(_resample_feature(obj, feature, symbol=symbol)) else: _new_obj.append(_resample_feature(obj[[feature]], feature, symbol=symbol)) if self.single_feature: new_obj = _new_obj[0] else: new_obj = pd.concat(_new_obj, axis=1) new_data[symbol] = new_obj return self.replace( wrapper=wrapper_meta["new_wrapper"], data=new_data, ) def realign( self: DataT, rule: tp.Optional[tp.AnyRuleLike] = None, *args, wrapper_meta: tp.DictLike = None, ffill: bool = True, **kwargs, ) -> DataT: """Perform realigning on `Data`. Looks for `realign_func` of each feature in `Data.feature_config`. If no function provided, resamples feature "open" with `vectorbtpro.generic.accessors.GenericAccessor.realign_opening` and other features with `vectorbtpro.generic.accessors.GenericAccessor.realign_closing`.""" if rule is None: rule = self.wrapper.freq if wrapper_meta is None: wrapper_meta = self.wrapper.resample_meta(rule, *args, **kwargs) def _realign_feature(obj, feature, symbol=None): realign_func = self.feature_config.get(feature, {}).get("realign_func", None) if realign_func is not None: if isinstance(realign_func, str): return getattr(obj.vbt, realign_func)(wrapper_meta["resampler"], ffill=ffill) return realign_func(self, obj, wrapper_meta["resampler"], ffill=ffill) if isinstance(feature, str) and feature.lower() == "open": return obj.vbt.realign_opening(wrapper_meta["resampler"], ffill=ffill) return obj.vbt.realign_closing(wrapper_meta["resampler"], ffill=ffill) new_data = self.dict_type() if self.feature_oriented: for feature in self.features: new_data[feature] = _realign_feature(self.data[feature], feature) else: for symbol, obj in self.data.items(): _new_obj = [] for feature in 
# ############# Running ############# #


@classmethod
def try_run(
    cls,
    data: "Data",
    func_name: str,
    *args,
    raise_errors: bool = False,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Any:
    """Call `data.run(*args, **kwargs)` and tolerate failures.

    If the call raises and `raise_errors` is True, the exception is re-raised. Otherwise,
    a warning consisting of `func_name`, a colon, and the error message is emitted
    (unless `silence_warnings` is True) and `NoResult` is returned."""
    try:
        result = data.run(*args, **kwargs)
    except Exception as e:
        if raise_errors:
            raise e
        if not silence_warnings:
            warn(func_name + ": " + str(e))
        return NoResult
    else:
        return result


@classmethod
def select_run_func_args(cls, i: int, func_name: str, args: tp.Args) -> tuple:
    """Pick the positional arguments that apply to the function with index `i` or name `func_name`.

    An argument of type `run_func_dict` is replaced by its value under `func_name`,
    then `i`, then "_def" (first match wins), and dropped if none of them is present.
    Any other argument is passed through unchanged."""
    selected = []
    for arg in args:
        if not isinstance(arg, run_func_dict):
            selected.append(arg)
            continue
        for lookup in (func_name, i, "_def"):
            if lookup in arg:
                selected.append(arg[lookup])
                break
    return tuple(selected)


@classmethod
def select_run_func_kwargs(cls, i: int, func_name: str, kwargs: tp.Kwargs) -> dict:
    """Pick the keyword arguments that apply to the function with index `i` or name `func_name`.

    A value of type `run_func_dict` is replaced by its entry under `func_name`, then `i`,
    then "_def" (first match wins), and dropped if none of them is present.
    A value of type `run_arg_dict` is merged into the result if its keyword equals
    `func_name` or `i`, and dropped otherwise. Any other value is passed through unchanged."""
    selected = {}
    for name, value in kwargs.items():
        if isinstance(value, run_func_dict):
            for lookup in (func_name, i, "_def"):
                if lookup in value:
                    selected[name] = value[lookup]
                    break
        elif isinstance(value, run_arg_dict):
            if func_name == name or i == name:
                selected.update(value)
        else:
            selected[name] = value
    return selected
def run(
    self,
    func: tp.MaybeIterable[tp.Union[tp.Hashable, tp.Callable]],
    *args,
    on_features: tp.Optional[tp.MaybeFeatures] = None,
    on_symbols: tp.Optional[tp.MaybeSymbols] = None,
    func_args: tp.ArgsLike = None,
    func_kwargs: tp.KwargsLike = None,
    magnet_kwargs: tp.KwargsLike = None,
    ignore_args: tp.Optional[tp.Sequence[str]] = None,
    rename_args: tp.DictLike = None,
    location: tp.Optional[str] = None,
    prepend_location: tp.Optional[bool] = None,
    unpack: tp.Union[bool, str] = False,
    concat: bool = True,
    data_kwargs: tp.KwargsLike = None,
    silence_warnings: bool = False,
    raise_errors: bool = False,
    execute_kwargs: tp.KwargsLike = None,
    filter_results: bool = True,
    raise_no_results: bool = True,
    merge_func: tp.MergeFuncLike = None,
    merge_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    return_keys: bool = False,
    _func_name: tp.Optional[str] = None,
    **kwargs,
) -> tp.Any:
    """Run a function on data.

    Looks into the signature of the function and searches for arguments with the name `data` or
    those found among features or attributes. For example, the argument `open` will be substituted
    by `Data.open`. Arguments passed explicitly via `**kwargs` are never substituted.

    `func` can be one of the following:

    * Location to compute all indicators from. See `vectorbtpro.indicators.factory.IndicatorFactory.list_locations`.
    * Indicator name. See `vectorbtpro.indicators.factory.IndicatorFactory.get_indicator`.
    * Simulation method. See `vectorbtpro.portfolio.base.Portfolio`.
    * Any callable object
    * Iterable with any of the above. Will be stacked as columns into a DataFrame.

    Use `on_features` and `on_symbols` to run the function on a selection of features and symbols.

    Use `func_args` and `func_kwargs` to pass additional positional and keyword arguments to
    the function. Use `magnet_kwargs` to provide keyword arguments that will be passed only if
    found in the signature of the function.

    Use `ignore_args` to exclude arguments from substitution. Use `rename_args` to rename arguments.
    For example, in `vectorbtpro.portfolio.base.Portfolio`, data can be passed instead of `close`.

    Set `unpack` to True, "dict", or "frame" to use
    `vectorbtpro.indicators.factory.IndicatorBase.unpack`,
    `vectorbtpro.indicators.factory.IndicatorBase.to_dict`, and
    `vectorbtpro.indicators.factory.IndicatorBase.to_frame` respectively.
    Set `unpack` to "data" to wrap the output with `Data.from_data` using `**data_kwargs`.
    Any other string raises an error.

    Any argument in `*args` and `**kwargs` can be wrapped with `run_func_dict`/`run_arg_dict`
    to specify the value per function/argument name or index when `func` is iterable.

    Multiple function calls are executed with `vectorbtpro.utils.execution.execute` using
    `**execute_kwargs`, each wrapped with `Data.try_run`. If `filter_results` is True, the results
    are passed through `filter_out_no_results`; if that raises `NoResultsException`, the exception
    is re-raised when `raise_no_results` is True, and `NoResult` is returned otherwise.
    If `merge_func` is None and `concat` is True, the results are merged with "column_stack".
    If `return_keys` is True, the keys of the functions are returned along with the results."""
    # NOTE(review): imported inside the function, presumably to avoid circular imports
    from vectorbtpro.indicators.factory import IndicatorBase, IndicatorFactory
    from vectorbtpro.indicators.talib_ import talib_func
    from vectorbtpro.portfolio.base import Portfolio

    if magnet_kwargs is None:
        magnet_kwargs = {}
    if data_kwargs is None:
        data_kwargs = {}
    if execute_kwargs is None:
        execute_kwargs = {}
    if merge_kwargs is None:
        merge_kwargs = {}
    # Everything below operates on the (optionally) selected subset
    _self = self
    if on_features is not None:
        _self = _self.select_features(on_features)
    if on_symbols is not None:
        _self = _self.select_symbols(on_symbols)

    # Case 1: multiple functions -> build one task per function and merge the results
    if checks.is_complex_iterable(func):
        tasks = []
        keys = []
        for i, f in enumerate(func):
            _location = location
            # Derive the name under which this function appears in the keys
            if callable(f):
                func_name = f.__name__
            elif isinstance(f, str):
                if _location is not None:
                    # Location provided by the caller: the string is the function name
                    func_name = f.lower().strip()
                    if func_name == "*":
                        func_name = "all"
                    if prepend_location is True:
                        func_name = _location + "_" + func_name
                else:
                    # No location provided: try to split it off the string
                    _location, f = IndicatorFactory.split_indicator_name(f)
                    if f is None:
                        raise ValueError("Sequence of locations is not supported")
                    func_name = f.lower().strip()
                    if func_name == "*":
                        func_name = "all"
                    if _location is not None:
                        if prepend_location in (None, True):
                            func_name = _location + "_" + func_name
            else:
                func_name = f
            # Resolve arguments that were specified per function
            new_args = _self.select_run_func_args(i, func_name, args)
            # Leading arguments consumed by try_run: data and function name; the rest goes to run
            new_args = (_self, func_name, f, *new_args)
            new_kwargs = _self.select_run_func_kwargs(i, func_name, kwargs)
            if concat and _location == "talib_func":
                new_kwargs["unpack_to"] = "frame"
            # Per-function keyword arguments take precedence over the shared ones
            new_kwargs = {
                **dict(
                    func_args=func_args,
                    func_kwargs=func_kwargs,
                    magnet_kwargs=magnet_kwargs,
                    ignore_args=ignore_args,
                    rename_args=rename_args,
                    location=_location,
                    prepend_location=prepend_location,
                    unpack="frame" if concat else unpack,
                    concat=concat,
                    data_kwargs=data_kwargs,
                    silence_warnings=silence_warnings,
                    raise_errors=raise_errors,
                    execute_kwargs=execute_kwargs,
                    merge_func=merge_func,
                    merge_kwargs=merge_kwargs,
                    template_context=template_context,
                    return_keys=return_keys,
                    _func_name=func_name,
                ),
                **new_kwargs,
            }
            tasks.append(Task(self.try_run, *new_args, **new_kwargs))
            keys.append(str(func_name))
        keys = pd.Index(keys, name="run_func")
        results = execute(tasks, size=len(keys), keys=keys, **execute_kwargs)
        if filter_results:
            try:
                results, keys = filter_out_no_results(results, keys=keys)
            except NoResultsException as e:
                if raise_no_results:
                    raise e
                return NoResult
            no_results_filtered = True
        else:
            no_results_filtered = False
        if merge_func is None and concat:
            merge_func = "column_stack"
        if merge_func is not None:
            if is_merge_func_from_config(merge_func):
                # User-provided merge_kwargs override the defaults built here
                merge_kwargs = merge_dicts(
                    dict(
                        keys=keys,
                        filter_results=not no_results_filtered,
                        raise_no_results=raise_no_results,
                    ),
                    merge_kwargs,
                )
            if isinstance(merge_func, MergeFunc):
                merge_func = merge_func.replace(merge_kwargs=merge_kwargs, context=template_context)
            else:
                merge_func = MergeFunc(merge_func, merge_kwargs=merge_kwargs, context=template_context)
            if return_keys:
                return merge_func(results), keys
            return merge_func(results)
        if return_keys:
            return results, keys
        return results

    # Case 2: a string -> simulation method, location, or indicator name
    if isinstance(func, str):
        func_name = func.lower().strip()
        # NOTE(review): getattr without a default raises AttributeError if Portfolio
        # has no attribute with this name - confirm this is intended
        if func_name.startswith("from_") and getattr(Portfolio, func_name):
            func = getattr(Portfolio, func_name)
            if func_args is None:
                func_args = ()
            if func_kwargs is None:
                func_kwargs = {}
            # The data instance itself is passed as the first argument
            pf = func(_self, *args, *func_args, **kwargs, **func_kwargs)
            if isinstance(pf, Portfolio) and unpack:
                raise ValueError("Portfolio cannot be unpacked")
            return pf
        if location is None:
            location, func_name = IndicatorFactory.split_indicator_name(func_name)
        if location is not None and (func_name is None or func_name == "all" or func_name == "*"):
            # A whole location was requested: list its indicators and run them as a sequence
            matched_location = IndicatorFactory.match_location(location)
            if matched_location is not None:
                location = matched_location
            if func_name == "all" or func_name == "*":
                if prepend_location is None:
                    prepend_location = True
            else:
                if prepend_location is None:
                    prepend_location = False
            if location == "talib_func":
                indicators = IndicatorFactory.list_indicators("talib", prepend_location=False)
            else:
                indicators = IndicatorFactory.list_indicators(location, prepend_location=False)
            return _self.run(
                indicators,
                *args,
                func_args=func_args,
                func_kwargs=func_kwargs,
                magnet_kwargs=magnet_kwargs,
                ignore_args=ignore_args,
                rename_args=rename_args,
                location=location,
                prepend_location=prepend_location,
                unpack=unpack,
                concat=concat,
                data_kwargs=data_kwargs,
                silence_warnings=silence_warnings,
                raise_errors=raise_errors,
                execute_kwargs=execute_kwargs,
                merge_func=merge_func,
                merge_kwargs=merge_kwargs,
                template_context=template_context,
                return_keys=return_keys,
                **kwargs,
            )
        # A single indicator: resolve the name into a callable
        if location is not None:
            matched_location = IndicatorFactory.match_location(location)
            if matched_location is not None:
                location = matched_location
            if location == "talib_func":
                func = talib_func(func_name)
            else:
                func = IndicatorFactory.get_indicator(func_name, location=location)
        else:
            func = IndicatorFactory.get_indicator(func_name)
    # For indicator classes, the function to call is their `run` method
    if isinstance(func, type) and issubclass(func, IndicatorBase):
        func = func.run

    # Case 3: a single callable -> substitute arguments found in its signature
    with_kwargs = {}
    func_arg_names = get_func_arg_names(func)
    for arg_name in func_arg_names:
        # `real_arg_name` is the name in the signature, `arg_name` is the name to look up
        real_arg_name = arg_name
        if ignore_args is not None:
            if arg_name in ignore_args:
                continue
        if rename_args is not None:
            if arg_name in rename_args:
                arg_name = rename_args[arg_name]
        # Arguments passed explicitly by keyword are left untouched
        if real_arg_name not in kwargs:
            if arg_name == "data":
                with_kwargs[real_arg_name] = _self
            elif arg_name == "wrapper":
                with_kwargs[real_arg_name] = _self.symbol_wrapper
            elif arg_name in ("input_shape", "shape"):
                with_kwargs[real_arg_name] = _self.shape
            elif arg_name in ("target_shape", "shape_2d"):
                with_kwargs[real_arg_name] = _self.shape_2d
            elif arg_name in ("input_index", "index"):
                with_kwargs[real_arg_name] = _self.index
            elif arg_name in ("input_columns", "columns"):
                with_kwargs[real_arg_name] = _self.columns
            elif arg_name == "freq":
                with_kwargs[real_arg_name] = _self.freq
            elif arg_name == "hlc3":
                with_kwargs[real_arg_name] = _self.hlc3
            elif arg_name == "ohlc4":
                with_kwargs[real_arg_name] = _self.ohlc4
            elif arg_name == "returns":
                with_kwargs[real_arg_name] = _self.returns
            else:
                # Otherwise look for a feature with this name
                # (presumably -1 signals that there is no such feature)
                feature_idx = _self.get_feature_idx(arg_name)
                if feature_idx != -1:
                    with_kwargs[real_arg_name] = _self.get_feature(feature_idx)
    # Copy to avoid modifying the caller's dictionary
    kwargs = dict(kwargs)
    for k, v in magnet_kwargs.items():
        if k in func_arg_names:
            kwargs[k] = v
    new_args, new_kwargs = extend_args(func, args, kwargs, **with_kwargs)
    if func_args is None:
        func_args = ()
    if func_kwargs is None:
        func_kwargs = {}
    out = func(*new_args, *func_args, **new_kwargs, **func_kwargs)

    # Post-process the output according to `unpack`
    if isinstance(unpack, bool):
        if unpack:
            if isinstance(out, IndicatorBase):
                out = out.unpack()
    elif isinstance(unpack, str) and unpack.lower() == "dict":
        if isinstance(out, IndicatorBase):
            out = out.to_dict()
        else:
            # Key the output by the function name
            if _func_name is None:
                feature_name = func.__name__
            else:
                feature_name = _func_name
            out = {feature_name: out}
    elif isinstance(unpack, str) and unpack.lower() == "frame":
        if isinstance(out, IndicatorBase):
            out = out.to_frame()
        elif isinstance(out, pd.Series):
            out = out.to_frame()
    elif isinstance(unpack, str) and unpack.lower() == "data":
        if isinstance(out, IndicatorBase):
            out = feature_dict(out.to_dict())
        else:
            if _func_name is None:
                feature_name = func.__name__
            else:
                feature_name = _func_name
            out = feature_dict({feature_name: out})
        out = Data.from_data(out, **data_kwargs)
    else:
        raise ValueError(f"Invalid unpack: '{unpack}'")
    return out
# ############# Persisting ############# #


def resolve_key_arg(
    self,
    arg: tp.Any,
    k: tp.Key,
    arg_name: str,
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    is_kwargs: bool = False,
) -> tp.Any:
    """Resolve the value of the argument `arg` named `arg_name` for the key `k`.

    If `check_dict_type` is True, first validates the argument with `check_dict_type`
    of this instance.

    An argument of type `key_dict` is indexed by `k`. Otherwise, if `is_kwargs` is True,
    the argument is passed through `select_key_kwargs`, else it's taken as-is.

    If the resolved value is a `CustomTemplate`, it's substituted using `template_context`.
    Otherwise, if `is_kwargs` is True, templates inside the value are substituted
    with `substitute_templates`. In both cases, `arg_name` is used as the evaluation id."""
    if check_dict_type:
        self.check_dict_type(arg, arg_name=arg_name)

    if isinstance(arg, key_dict):
        resolved = arg[k]
    elif is_kwargs:
        resolved = self.select_key_kwargs(k, arg, check_dict_type=check_dict_type)
    else:
        resolved = arg

    if isinstance(resolved, CustomTemplate):
        return resolved.substitute(template_context, eval_id=arg_name)
    if is_kwargs:
        return substitute_templates(resolved, template_context, eval_id=arg_name)
    return resolved
def to_csv(
    self,
    path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
    ext: tp.Union[str, feature_dict, symbol_dict, CustomTemplate] = "csv",
    mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    return_meta: bool = False,
    **kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
    """Save data to CSV file(s).

    Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html

    Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
    depending on the format of the data dictionary.

    If `path_or_buf` is an existing directory or a path without a suffix, each feature/symbol
    is saved to a separate file inside it, named after the key and with the extension `ext`.
    If there's only one file, you can specify the file path via `path_or_buf`. If there are
    multiple files, use the same argument but wrap the multiple paths with `key_dict`.
    The parent directory of each file is checked with `check_mkdir` using `**mkdir_kwargs`.

    If `sep` is not among the keyword arguments, it becomes a tab for files with the suffix
    ".tsv" and a comma otherwise.

    If `return_meta` is True, returns the keyword arguments that were passed to `to_csv`
    of each object, keyed by feature/symbol."""
    meta = self.dict_type()
    for k, v in self.data.items():
        role = "feature" if self.feature_oriented else "symbol"
        context = merge_dicts({"data": v, "key": k, role: k}, template_context)
        shared = dict(check_dict_type=check_dict_type, template_context=context)

        csv_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
        sep = csv_kwargs.pop("sep", None)
        target = self.resolve_key_arg(path_or_buf, k, "path_or_buf", **shared)
        if isinstance(target, str):
            target = Path(target)
        if isinstance(target, Path):
            if (target.exists() and target.is_dir()) or target.suffix == "":
                # Target is a directory: one file per key
                file_ext = self.resolve_key_arg(ext, k, "ext", **shared)
                target /= f"{k}.{file_ext}"
            dir_kwargs = self.resolve_key_arg(mkdir_kwargs, k, "mkdir_kwargs", is_kwargs=True, **shared)
            check_mkdir(target.parent, **dir_kwargs)
            if sep is None:
                # Derive the separator from the suffix
                suffix = target.suffix.lower()
                if suffix == ".csv":
                    sep = ","
                elif suffix == ".tsv":
                    sep = "\t"
            target = str(target)
        if sep is None:
            sep = ","

        meta[k] = {"path_or_buf": target, "sep": sep, **csv_kwargs}
        v.to_csv(**meta[k])
    return meta if return_meta else None
@classmethod
def from_csv(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
    """Load data from CSV with `vectorbtpro.data.custom.csv.CSVData` and switch the class
    back to this class.

    All positional and keyword arguments are forwarded to `CSVData.pull`.

    Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
    from vectorbtpro.data.custom.csv import CSVData

    original_fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    pulled = CSVData.pull(*args, **kwargs)
    switched = pulled.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
    return switched.update_fetch_kwargs(**original_fetch_kwargs)
def to_hdf(
    self,
    path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
    key: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
    mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
    format: str = "table",
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    return_meta: bool = False,
    **kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
    """Save data to an HDF file using PyTables.

    Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_hdf.html

    Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
    depending on the format of the data dictionary.

    If `path_or_buf` is an existing directory or a path without a suffix, will create inside it
    a file named after this class and with the extension ".h5". The parent directory of the file
    is checked with `check_mkdir` using `**mkdir_kwargs`.

    If `key` is None, each feature/symbol is stored under its own key converted to a string.

    If `return_meta` is True, returns the keyword arguments that were passed to `to_hdf`
    of each object, keyed by feature/symbol."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("tables")

    meta = self.dict_type()
    for k, v in self.data.items():
        role = "feature" if self.feature_oriented else "symbol"
        context = merge_dicts({"data": v, "key": k, role: k}, template_context)
        shared = dict(check_dict_type=check_dict_type, template_context=context)

        target = self.resolve_key_arg(path_or_buf, k, "path_or_buf", **shared)
        if isinstance(target, str):
            target = Path(target)
        if isinstance(target, Path):
            if (target.exists() and target.is_dir()) or target.suffix == "":
                # Target is a directory: all keys go into one file named after the class
                target /= type(self).__name__ + ".h5"
            dir_kwargs = self.resolve_key_arg(mkdir_kwargs, k, "mkdir_kwargs", is_kwargs=True, **shared)
            check_mkdir(target.parent, **dir_kwargs)
            target = str(target)
        hdf_key = str(k) if key is None else self.resolve_key_arg(key, k, "key", **shared)
        hdf_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)

        meta[k] = {"path_or_buf": target, "key": hdf_key, "format": format, **hdf_kwargs}
        v.to_hdf(**meta[k])
    return meta if return_meta else None
@classmethod
def from_hdf(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
    """Load data from HDF with `vectorbtpro.data.custom.hdf.HDFData` and switch the class
    back to this class.

    All positional and keyword arguments are forwarded to `HDFData.pull`.

    Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
    from vectorbtpro.data.custom.hdf import HDFData

    original_fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    pulled = HDFData.pull(*args, **kwargs)
    switched = pulled.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
    return switched.update_fetch_kwargs(**original_fetch_kwargs)
def to_feather(
    self,
    path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
    mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    return_meta: bool = False,
    **kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
    """Save data to Feather file(s) using PyArrow.

    Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_feather.html

    Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
    depending on the format of the data dictionary.

    If `path_or_buf` is an existing directory or a path without a suffix, each feature/symbol
    is saved to a separate file inside it, named after the key and with the extension ".feather".
    If there's only one file, you can specify the file path via `path_or_buf`. If there are
    multiple files, use the same argument but wrap the multiple paths with `key_dict`.
    The parent directory of each file is checked with `check_mkdir` using `**mkdir_kwargs`.

    Series are converted to frames. If writing fails with a `ValueError` that suggests
    resetting the index, the index is reset and writing is attempted once more.

    If `return_meta` is True, returns the keyword arguments that were passed to `to_feather`
    of each object, keyed by feature/symbol."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("pyarrow")

    meta = self.dict_type()
    for k, v in self.data.items():
        role = "feature" if self.feature_oriented else "symbol"
        context = merge_dicts({"data": v, "key": k, role: k}, template_context)
        shared = dict(check_dict_type=check_dict_type, template_context=context)

        target = self.resolve_key_arg(path_or_buf, k, "path_or_buf", **shared)
        if isinstance(target, str):
            target = Path(target)
        if isinstance(target, Path):
            if (target.exists() and target.is_dir()) or target.suffix == "":
                # Target is a directory: one file per key
                target /= f"{k}.feather"
            dir_kwargs = self.resolve_key_arg(mkdir_kwargs, k, "mkdir_kwargs", is_kwargs=True, **shared)
            check_mkdir(target.parent, **dir_kwargs)
            target = str(target)
        feather_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)

        meta[k] = {"path": target, **feather_kwargs}
        frame = v.to_frame() if isinstance(v, pd.Series) else v
        try:
            frame.to_feather(**meta[k])
        except ValueError as e:
            if "you can .reset_index()" not in str(e):
                raise
            # Retry with the index moved into the columns
            frame.reset_index().to_feather(**meta[k])
    return meta if return_meta else None
@classmethod
def from_feather(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
    """Load data from Feather with `vectorbtpro.data.custom.feather.FeatherData` and switch
    the class back to this class.

    All positional and keyword arguments are forwarded to `FeatherData.pull`.

    Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
    from vectorbtpro.data.custom.feather import FeatherData

    original_fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    pulled = FeatherData.pull(*args, **kwargs)
    switched = pulled.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
    return switched.update_fetch_kwargs(**original_fetch_kwargs)
def to_parquet(
    self,
    path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
    mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
    partition_cols: tp.Union[None, tp.List[str], feature_dict, symbol_dict, CustomTemplate] = None,
    partition_by: tp.Union[None, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = None,
    period_index_to: tp.Union[str, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = "str",
    groupby_kwargs: tp.Union[None, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = None,
    keep_groupby_names: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = False,
    engine: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    return_meta: bool = False,
    **kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
    """Save data to Parquet file(s) using PyArrow or FastParquet.

    Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_parquet.html

    Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
    depending on the format of the data dictionary.

    If `path_or_buf` is a path to a directory, will save each feature/symbol to a separate file.
    If there's only one file, you can specify the file path via `path_or_buf`. If there are
    multiple files, use the same argument but wrap the multiple paths with `key_dict`.

    If `partition_cols` and `partition_by` are None, `path_or_buf` must be a file, otherwise
    it must be a directory. If `partition_by` is not None, will group the index by using
    `vectorbtpro.base.wrapping.ArrayWrapper.get_index_grouper` with `**groupby_kwargs` and
    put it inside `partition_cols`. In this case, `partition_cols` must be None.

    The group columns are named "group" (or "group_{i}" for each level of a multi-index),
    or after the names of the grouped index if `keep_groupby_names` is True. Period indexes
    are converted to strings if `period_index_to` is "str", otherwise to timestamps with
    `period_index_to` passed as `how`.

    `engine` is resolved with `vectorbtpro.data.custom.parquet.ParquetData.resolve_custom_setting`
    and must resolve to "pyarrow", "fastparquet", or "auto".

    If `return_meta` is True, returns the keyword arguments that were passed to `to_parquet`
    of each object, keyed by feature/symbol."""
    from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any
    from vectorbtpro.data.custom.parquet import ParquetData

    meta = self.dict_type()
    for k, v in self.data.items():
        # Make the current object and its key available to templates
        if self.feature_oriented:
            _template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
        else:
            _template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
        _partition_cols = self.resolve_key_arg(
            partition_cols,
            k,
            "partition_cols",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _partition_by = self.resolve_key_arg(
            partition_by,
            k,
            "partition_by",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        if _partition_cols is not None and _partition_by is not None:
            raise ValueError("Must use either partition_cols or partition_by, not both")
        _path_or_buf = self.resolve_key_arg(
            path_or_buf,
            k,
            "path_or_buf",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        if isinstance(_path_or_buf, str):
            _path_or_buf = Path(_path_or_buf)
        if isinstance(_path_or_buf, Path):
            if (_path_or_buf.exists() and _path_or_buf.is_dir()) or _path_or_buf.suffix == "":
                # Partitioned data is written into a directory, non-partitioned into a file
                if _partition_cols is not None or _partition_by is not None:
                    _path_or_buf /= f"{k}"
                else:
                    _path_or_buf /= f"{k}.parquet"
            _mkdir_kwargs = self.resolve_key_arg(
                mkdir_kwargs,
                k,
                "mkdir_kwargs",
                check_dict_type=check_dict_type,
                template_context=_template_context,
                is_kwargs=True,
            )
            check_mkdir(_path_or_buf.parent, **_mkdir_kwargs)
            _path_or_buf = str(_path_or_buf)
        _engine = self.resolve_key_arg(
            ParquetData.resolve_custom_setting(engine, "engine"),
            k,
            "engine",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        # Make sure the package behind the engine can be imported
        if _engine == "pyarrow":
            assert_can_import("pyarrow")
        elif _engine == "fastparquet":
            assert_can_import("fastparquet")
        elif _engine == "auto":
            assert_can_import_any("pyarrow", "fastparquet")
        else:
            raise ValueError(f"Invalid engine: '{_engine}'")
        if isinstance(v, pd.Series):
            v = v.to_frame()
        if _partition_by is not None:
            _period_index_to = self.resolve_key_arg(
                period_index_to,
                k,
                "period_index_to",
                check_dict_type=check_dict_type,
                template_context=_template_context,
            )
            _groupby_kwargs = self.resolve_key_arg(
                groupby_kwargs,
                k,
                "groupby_kwargs",
                check_dict_type=check_dict_type,
                template_context=_template_context,
                is_kwargs=True,
            )
            _keep_groupby_names = self.resolve_key_arg(
                keep_groupby_names,
                k,
                "keep_groupby_names",
                check_dict_type=check_dict_type,
                template_context=_template_context,
            )
            # Shallow copy since group columns are added below
            v = v.copy(deep=False)
            grouper = self.wrapper.get_index_grouper(_partition_by, **_groupby_kwargs)
            partition_index = grouper.get_stretched_index()
            _partition_cols = []

            def _convert_period_index(index):
                # Convert a period index to either strings or timestamps
                if _period_index_to == "str":
                    return index.map(str)
                return index.to_timestamp(how=_period_index_to)

            if isinstance(partition_index, pd.MultiIndex):
                # One group column per level
                for i in range(partition_index.nlevels):
                    partition_level = partition_index.get_level_values(i)
                    if _keep_groupby_names:
                        new_column_name = partition_level.name
                    else:
                        new_column_name = f"group_{i}"
                    if isinstance(partition_level, pd.PeriodIndex):
                        partition_level = _convert_period_index(partition_level)
                    v[new_column_name] = partition_level
                    _partition_cols.append(new_column_name)
            else:
                if _keep_groupby_names:
                    new_column_name = partition_index.name
                else:
                    new_column_name = "group"
                if isinstance(partition_index, pd.PeriodIndex):
                    partition_index = _convert_period_index(partition_index)
                v[new_column_name] = partition_index
                _partition_cols.append(new_column_name)
        _kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
        meta[k] = {"path": _path_or_buf, "partition_cols": _partition_cols, "engine": _engine, **_kwargs}
        v.to_parquet(**meta[k])

    if return_meta:
        return meta
    return None
@classmethod
def from_parquet(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
    """Load data from Parquet with `vectorbtpro.data.custom.parquet.ParquetData` and switch
    the class back to this class.

    All positional and keyword arguments are forwarded to `ParquetData.pull`.

    Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
    from vectorbtpro.data.custom.parquet import ParquetData

    original_fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    pulled = ParquetData.pull(*args, **kwargs)
    switched = pulled.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
    return switched.update_fetch_kwargs(**original_fetch_kwargs)
def to_sql(
    self,
    engine: tp.Union[None, str, EngineT, feature_dict, symbol_dict, CustomTemplate] = None,
    table: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
    schema: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
    to_utc: tp.Union[None, bool, str, tp.Sequence[str], feature_dict, symbol_dict, CustomTemplate] = None,
    remove_utc_tz: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = True,
    attach_row_number: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = False,
    from_row_number: tp.Union[None, int, feature_dict, symbol_dict, CustomTemplate] = None,
    row_number_column: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
    engine_config: tp.KwargsLike = None,
    dispose_engine: tp.Optional[bool] = None,
    check_dict_type: bool = True,
    template_context: tp.KwargsLike = None,
    return_meta: bool = False,
    return_engine: bool = False,
    **kwargs,
) -> tp.Union[None, feature_dict, symbol_dict, EngineT]:
    """Save data to a SQL database using SQLAlchemy.

    Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html

    Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
    depending on the format of the data dictionary. Each feature/symbol gets saved to a separate table.
    If `table` is None, the table is named after the feature/symbol.

    If `engine` is None or a string, will resolve an engine with
    `vectorbtpro.data.custom.sql.SQLData.resolve_engine` and dispose it afterward if `dispose_engine`
    is None or True. It can additionally return the engine if `return_engine` is True or entire
    metadata (all passed arguments as `feature_dict` or `symbol_dict`). In this case, the engine
    won't be disposed by default. An engine that was passed as an object is disposed only if
    `dispose_engine` is True.

    If `schema` is not None and it doesn't exist, will create a new schema.

    If `attach_row_number` is True, will attach a column with the name `row_number_column`
    containing consecutive row numbers that start at `from_row_number`. If `from_row_number`
    is None, numbering starts at 0 if the table doesn't exist or if `if_exists` (in `**kwargs`,
    defaults to "fail") is not "append", otherwise at the last row number of the table plus one.

    For `to_utc` and `remove_utc_tz`, see `Data.prepare_dt`. If `to_utc` is None, uses the
    corresponding setting of `vectorbtpro.data.custom.sql.SQLData`."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("sqlalchemy")
    from vectorbtpro.data.custom.sql import SQLData

    if engine_config is None:
        engine_config = {}
    # Resolve the engine just once if neither the engine nor its config is specified per key
    if (engine is None or isinstance(engine, str)) and not self.has_key_dict(engine_config):
        engine_meta = SQLData.resolve_engine(
            engine=engine,
            return_meta=True,
            **engine_config,
        )
        engine = engine_meta["engine"]
        engine_name = engine_meta["engine_name"]
        should_dispose = engine_meta["should_dispose"]
        if dispose_engine is None:
            # Keep the engine alive if it's handed over to the caller
            if return_meta or return_engine:
                dispose_engine = False
            else:
                dispose_engine = should_dispose
    else:
        engine_name = None
        if return_engine:
            raise ValueError("Engine can be returned only if URL was provided")

    meta = self.dict_type()
    for k, v in self.data.items():
        # Make the current object and its key available to templates
        if self.feature_oriented:
            _template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
        else:
            _template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
        _engine = self.resolve_key_arg(
            engine,
            k,
            "engine",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _engine_config = self.resolve_key_arg(
            engine_config,
            k,
            "engine_config",
            check_dict_type=check_dict_type,
            template_context=_template_context,
            is_kwargs=True,
        )
        if _engine is None or isinstance(_engine, str):
            # Engine is still unresolved for this key: resolve it here
            _engine_meta = SQLData.resolve_engine(
                engine=_engine,
                return_meta=True,
                **_engine_config,
            )
            _engine = _engine_meta["engine"]
            _engine_name = _engine_meta["engine_name"]
            _should_dispose = _engine_meta["should_dispose"]
            if dispose_engine is None:
                if return_meta or return_engine:
                    _dispose_engine = False
                else:
                    _dispose_engine = _should_dispose
            else:
                _dispose_engine = dispose_engine
        else:
            # Engine is already an object: don't dispose it unless explicitly requested
            _engine_name = engine_name
            if dispose_engine is None:
                _dispose_engine = False
            else:
                _dispose_engine = dispose_engine
        if table is None:
            _table = k
        else:
            _table = self.resolve_key_arg(
                table,
                k,
                "table",
                check_dict_type=check_dict_type,
                template_context=_template_context,
            )
        _schema = self.resolve_key_arg(
            SQLData.resolve_engine_setting(schema, "schema", engine_name=_engine_name),
            k,
            "schema",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _to_utc = self.resolve_key_arg(
            SQLData.resolve_engine_setting(to_utc, "to_utc", engine_name=_engine_name),
            k,
            "to_utc",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _remove_utc_tz = self.resolve_key_arg(
            remove_utc_tz,
            k,
            "remove_utc_tz",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _attach_row_number = self.resolve_key_arg(
            attach_row_number,
            k,
            "attach_row_number",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _from_row_number = self.resolve_key_arg(
            from_row_number,
            k,
            "from_row_number",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        _row_number_column = self.resolve_key_arg(
            SQLData.resolve_engine_setting(row_number_column, "row_number_column", engine_name=_engine_name),
            k,
            "row_number_column",
            check_dict_type=check_dict_type,
            template_context=_template_context,
        )
        v = SQLData.prepare_dt(v, to_utc=_to_utc, remove_utc_tz=_remove_utc_tz, parse_dates=False)
        _kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
        if _attach_row_number:
            # Shallow copy since a column is added below
            v = v.copy(deep=False)
            if isinstance(v, pd.Series):
                v = v.to_frame()
            if _from_row_number is None:
                if not SQLData.has_table(_table, schema=_schema, engine=_engine):
                    _from_row_number = 0
                elif _kwargs.get("if_exists", "fail") != "append":
                    _from_row_number = 0
                else:
                    # Appending to an existing table: continue its numbering
                    last_row_number = SQLData.get_last_row_number(
                        _table,
                        schema=_schema,
                        row_number_column=_row_number_column,
                        engine=_engine,
                    )
                    _from_row_number = last_row_number + 1
            v[_row_number_column] = np.arange(_from_row_number, _from_row_number + len(v.index))
        if _schema is not None:
            SQLData.create_schema(_schema, engine=_engine)
        meta[k] = {"name": _table, "con": _engine, "schema": _schema, **_kwargs}
        v.to_sql(**meta[k])
        if _dispose_engine:
            _engine.dispose()

    if return_meta:
        return meta
    if return_engine:
        return engine
    return None
@classmethod
def from_sql(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
    """Load data from a SQL database with `vectorbtpro.data.custom.sql.SQLData` and switch
    the class back to this class.

    All positional and keyword arguments are forwarded to `SQLData.pull`.

    Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
    from vectorbtpro.data.custom.sql import SQLData

    original_fetch_kwargs = {} if fetch_kwargs is None else fetch_kwargs
    pulled = SQLData.pull(*args, **kwargs)
    switched = pulled.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
    return switched.update_fetch_kwargs(**original_fetch_kwargs)
symbol_dict, CustomTemplate] = None, remove_utc_tz: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = True, if_exists: tp.Union[str, feature_dict, symbol_dict, CustomTemplate] = "fail", connection_config: tp.KwargsLike = None, check_dict_type: bool = True, template_context: tp.KwargsLike = None, return_meta: bool = False, return_connection: bool = False, ) -> tp.Union[None, feature_dict, symbol_dict, DuckDBPyConnectionT]: """Save data to a DuckDB database. Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`, depending on the format of the data dictionary. If `connection` is None or a string, will resolve a connection with `vectorbtpro.data.custom.duckdb.DuckDBData.resolve_connection`. It can additionally return the connection if `return_connection` is True or entire metadata (all passed arguments as `feature_dict` or `symbol_dict`). In this case, the engine won't be disposed by default. If `write_format` is None and `write_path` is a directory (default), will persist each feature/symbol to a table (see https://duckdb.org/docs/guides/python/import_pandas). If `catalog` is not None, will make it default for this connection. If `schema` is not None, and it doesn't exist, will create a new schema in the current catalog and make it default for this connection. Any new table will be automatically created under this schema. If `if_exists` is "fail", will raise an error if a table with the same name already exists. If `if_exists` is "replace", will drop the existing table first. If `if_exists` is "append", will append the new table to the existing one. If `write_format` is not None, it must be either "csv", "parquet", or "json". If `write_path` is a directory or has no suffix (meaning it's not a file), each feature/symbol will be saved to a separate file under that path and with the provided `write_format` as extension. 
The data will be saved using a `COPY` mechanism (see https://duckdb.org/docs/sql/statements/copy.html). To provide options to the write operation, pass them as a dictionary or an already formatted string (without brackets). For example, `dict(compression="gzip")` is same as "COMPRESSION 'gzip'". For `to_utc` and `remove_utc_tz`, see `Data.prepare_dt`. If `to_utc` is None, uses the corresponding setting of `vectorbtpro.data.custom.duckdb.DuckDBData`.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("duckdb") from vectorbtpro.data.custom.duckdb import DuckDBData if connection_config is None: connection_config = {} if (connection is None or isinstance(connection, (str, Path))) and not self.has_key_dict(connection_config): connection_meta = DuckDBData.resolve_connection( connection=connection, read_only=False, return_meta=True, **connection_config, ) connection = connection_meta["connection"] if return_meta or return_connection: should_close = False else: should_close = connection_meta["should_close"] elif return_connection: raise ValueError("Connection can be returned only if URL was provided") else: should_close = False meta = self.dict_type() for k, v in self.data.items(): if self.feature_oriented: _template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context) else: _template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context) _connection = self.resolve_key_arg( connection, k, "connection", check_dict_type=check_dict_type, template_context=_template_context, ) _connection_config = self.resolve_key_arg( connection_config, k, "connection_config", check_dict_type=check_dict_type, template_context=_template_context, is_kwargs=True, ) if _connection is None or isinstance(_connection, (str, Path)): _connection_meta = DuckDBData.resolve_connection( connection=_connection, read_only=False, return_meta=True, **_connection_config, ) _connection = _connection_meta["connection"] _should_close = 
_connection_meta["should_close"] else: _should_close = False if table is None: _table = k else: _table = self.resolve_key_arg( table, k, "table", check_dict_type=check_dict_type, template_context=_template_context, ) _schema = self.resolve_key_arg( DuckDBData.resolve_custom_setting(schema, "schema"), k, "schema", check_dict_type=check_dict_type, template_context=_template_context, ) _catalog = self.resolve_key_arg( DuckDBData.resolve_custom_setting(catalog, "catalog"), k, "catalog", check_dict_type=check_dict_type, template_context=_template_context, ) _write_format = self.resolve_key_arg( write_format, k, "write_format", check_dict_type=check_dict_type, template_context=_template_context, ) _write_path = self.resolve_key_arg( write_path, k, "write_path", check_dict_type=check_dict_type, template_context=_template_context, ) _write_path = Path(_write_path) is_not_file = (_write_path.exists() and _write_path.is_dir()) or _write_path.suffix == "" if _write_format is not None and is_not_file: if _write_format.upper() == "CSV": _write_path /= f"{k}.csv" elif _write_format.upper() == "PARQUET": _write_path /= f"{k}.parquet" elif _write_format.upper() == "JSON": _write_path /= f"{k}.json" else: raise ValueError(f"Invalid write format: '{_write_format}'") if _write_path.suffix != "": _mkdir_kwargs = self.resolve_key_arg( mkdir_kwargs, k, "mkdir_kwargs", check_dict_type=check_dict_type, template_context=_template_context, is_kwargs=True, ) check_mkdir(_write_path.parent, **_mkdir_kwargs) _write_path = str(_write_path) use_write = True else: use_write = False _to_utc = self.resolve_key_arg( DuckDBData.resolve_custom_setting(to_utc, "to_utc"), k, "to_utc", check_dict_type=check_dict_type, template_context=_template_context, ) _remove_utc_tz = self.resolve_key_arg( remove_utc_tz, k, "remove_utc_tz", check_dict_type=check_dict_type, template_context=_template_context, ) _if_exists = self.resolve_key_arg( if_exists, k, "if_exists", check_dict_type=check_dict_type, 
template_context=_template_context, ) v = DuckDBData.prepare_dt(v, to_utc=_to_utc, remove_utc_tz=_remove_utc_tz, parse_dates=False) v = v.reset_index() if use_write: _write_options = self.resolve_key_arg( write_options, k, "write_options", check_dict_type=check_dict_type, template_context=_template_context, is_kwargs=isinstance(write_options, dict), ) if _write_options is not None: _write_options = DuckDBData.format_write_options(_write_options) if _write_format is not None and _write_options is not None and "FORMAT" not in _write_options: _write_options = f"FORMAT {_write_format.upper()}, " + _write_options elif _write_format is not None and _write_options is None: _write_options = f"FORMAT {_write_format.upper()}" _connection.register("_" + k, v) if _write_options is not None: _connection.sql(f"COPY (SELECT * FROM \"_{k}\") TO '{_write_path}' ({_write_options})") else: _connection.sql(f"COPY (SELECT * FROM \"_{k}\") TO '{_write_path}'") meta[k] = {"write_path": _write_path, "write_options": _write_options} else: if _catalog is not None: _connection.sql(f"USE {_catalog}") elif _schema is not None: catalogs = DuckDBData.list_catalogs(connection=_connection) if len(catalogs) > 1: raise ValueError("Please select a catalog") _catalog = catalogs[0] _connection.sql(f"USE {_catalog}") if _schema is not None: _connection.sql(f'CREATE SCHEMA IF NOT EXISTS "{_schema}"') _connection.sql(f"USE {_catalog}.{_schema}") append = False if _table in DuckDBData.list_tables(catalog=_catalog, schema=_schema, connection=_connection): if _if_exists.lower() == "fail": raise ValueError(f"Table '{_table}' already exists") elif _if_exists.lower() == "replace": _connection.sql(f'DROP TABLE "{_table}"') elif _if_exists.lower() == "append": append = True _connection.register("_" + k, v) if append: _connection.sql(f'INSERT INTO "{_table}" SELECT * FROM "_{k}"') else: _connection.sql(f'CREATE TABLE "{_table}" AS SELECT * FROM "_{k}"') meta[k] = {"table": _table, "schema": _schema, "catalog": 
def sql(
    self,
    query: str,
    dbcon: tp.Optional[DuckDBPyConnectionT] = None,
    database: str = ":memory:",
    db_config: tp.KwargsLike = None,
    alias: str = "",
    params: tp.KwargsLike = None,
    other_objs: tp.Optional[dict] = None,
    date_as_object: bool = False,
    align_dtypes: bool = True,
    squeeze: bool = True,
    **kwargs,
) -> tp.SeriesFrame:
    """Run a SQL query on this instance using DuckDB.

    First, connection gets established: if `dbcon` is None, a new connection to `database`
    is opened with `db_config` as configuration and closed again once the result has been
    fetched. A connection passed by the user is never closed here.

    Then, `Data.get` gets invoked with `**kwargs` passed as keyword arguments and `as_dict=True`.
    Then, each returned object gets registered within the database under its key. Objects with
    a non-default index have their index reset into columns, and Series are converted to
    DataFrames. Objects in `other_objs` (must be a dictionary) are registered the same way.
    Finally, the query gets executed with `duckdb.sql` and the relation as a DataFrame gets returned.

    If `align_dtypes` is True, each returned column that shares its name with a registered
    column is cast to the data type of that column.

    The index gets restored from the columns of the result if possible: all level names in
    case of a multi-index, otherwise the index name, otherwise the column "index" (in which
    case the index name is set to None).

    If `squeeze` is True, a DataFrame with one column will be converted into a Series."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("duckdb")
    from duckdb import connect

    if db_config is None:
        db_config = {}
    if params is None:
        params = {}
    # Only a connection that was opened here may be closed here
    owns_connection = dbcon is None
    if owns_connection:
        dbcon = connect(database=database, read_only=False, config=db_config)

    dtypes = {}

    def _register_objs(objs: dict) -> None:
        # Register each object as a table and remember the data types of its columns
        for k, v in objs.items():
            if not checks.is_default_index(v.index):
                v = v.reset_index()
            if isinstance(v, pd.Series):
                v = v.to_frame()
            for c in v.columns:
                dtypes[c] = v[c].dtype
            dbcon.register(k, v)

    try:
        _register_objs(self.get(**kwargs, as_dict=True))
        if other_objs is not None:
            checks.assert_instance_of(other_objs, dict, arg_name="other_objs")
            _register_objs(other_objs)
        # Materialize the relation before the connection gets closed
        df = dbcon.sql(query, alias=alias, params=params).df(date_as_object=date_as_object)
    finally:
        if owns_connection:
            dbcon.close()

    if align_dtypes:
        for c in df.columns:
            if c in dtypes:
                df[c] = df[c].astype(dtypes[c])
    if isinstance(self.index, pd.MultiIndex):
        if set(self.index.names) <= set(df.columns):
            df = df.set_index(self.index.names)
    else:
        if self.index.name is not None and self.index.name in df.columns:
            df = df.set_index(self.index.name)
        elif "index" in df.columns:
            # An unnamed index becomes the column "index" after resetting
            df = df.set_index("index")
            df.index.name = None
    if squeeze and len(df.columns) == 1:
        df = df.iloc[:, 0]
    return df
def plot(
    self,
    column: tp.Optional[tp.Hashable] = None,
    feature: tp.Optional[tp.Feature] = None,
    symbol: tp.Optional[tp.Symbol] = None,
    feature_map: tp.KwargsLike = None,
    plot_volume: tp.Optional[bool] = None,
    base: tp.Optional[float] = None,
    **kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
    """Plot either one feature of multiple symbols, or OHLC(V) of one symbol.

    OHLC(V) is plotted if no feature has been selected and the instance has OHLC.
    Otherwise, the selected data is plotted as lines.

    Args:
        column (hashable): Name of the feature or symbol to plot.

            Depends on the data orientation: it is used as the symbol if the data is
            feature-oriented and as the feature otherwise. Cannot be combined with
            the respective argument.
        feature (hashable): Name of the feature to plot.
        symbol (hashable): Name of the symbol to plot.
        feature_map (dict): Dictionary mapping the feature names to OHLCV.

            Applied only if OHLC(V) is plotted.
        plot_volume (bool): Whether to plot volume beneath.

            Applied only if OHLC(V) is plotted.
        base (float): Rebase all series of a feature to a given initial base.

            !!! note
                The feature must contain prices.

            Applied only if lines are plotted.
        kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`
            for lines and to `vectorbtpro.ohlcv.accessors.OHLCVDFAccessor.plot` for OHLC(V).

    Usage:
        * Plot the lines of one feature across all symbols:

        ```pycon
        >>> from vectorbtpro import *

        >>> start = '2021-01-01 UTC'  # crypto is in UTC
        >>> end = '2021-06-01 UTC'
        >>> data = vbt.YFData.pull(['BTC-USD', 'ETH-USD', 'ADA-USD'], start=start, end=end)
        ```

        [=100% "100%"]{: .candystripe .candystripe-animate }

        ```pycon
        >>> data.plot(feature='Close', base=1).show()
        ```

        * Plot OHLC(V) of one symbol (only if data contains the respective features):

        ![](/assets/images/api/data_plot.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/data_plot.dark.svg#only-dark){: .iimg loading=lazy }

        ```pycon
        >>> data.plot(symbol='BTC-USD').show()
        ```

        ![](/assets/images/api/data_plot_ohlcv.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/data_plot_ohlcv.dark.svg#only-dark){: .iimg loading=lazy }
    """
    if column is not None:
        # Translate the orientation-agnostic `column` into the respective key argument
        if self.feature_oriented:
            if symbol is not None:
                raise ValueError("Either column or symbol can be provided, not both")
            symbol = column
        elif feature is not None:
            raise ValueError("Either column or feature can be provided, not both")
        else:
            feature = column

    plot_ohlc = feature is None and self.has_ohlc
    if plot_ohlc:
        ohlc_data = self.get(symbols=symbol, squeeze_symbols=True)
        if isinstance(ohlc_data, tuple):
            raise ValueError("Cannot plot OHLC of multiple symbols. Select one symbol.")
        ohlcv_accessor = ohlc_data.vbt.ohlcv(feature_map=feature_map)
        return ohlcv_accessor.plot(plot_volume=plot_volume, **kwargs)

    line_data = self.get(
        features=feature,
        symbols=symbol,
        squeeze_features=True,
        squeeze_symbols=True,
    )
    if isinstance(line_data, tuple):
        raise ValueError("Cannot plot multiple features and symbols. Select one feature or symbol.")
    if base is not None:
        line_data = line_data.vbt.rebase(base)
    return line_data.vbt.lineplot(**kwargs)
def attach_symbol_dict_methods(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
    """Class decorator to attach methods for selecting from and updating key dictionaries.

    For each name in `cls._key_dict_attrs`, attaches a method `select_{name}` that selects
    a feature or symbol from the respective attribute. For each name in `cls._updatable_attrs`,
    attaches a method `update_{name}` that returns a new instance with the respective
    attribute updated. The class must be a subclass of `Data`. Returns the same class."""

    checks.assert_subclass_of(cls, "Data")

    DataT = tp.TypeVar("DataT", bound="Data")

    def _attach(method, name: str, doc: str) -> None:
        # Make the generated function look like a regular method of the class and attach it
        method.__name__ = name
        method.__module__ = cls.__module__
        method.__qualname__ = f"{cls.__name__}.{method.__name__}"
        method.__doc__ = doc
        setattr(cls, method.__name__, method)

    for target_name in cls._key_dict_attrs:

        # The target name is bound as a default to avoid late binding of the loop variable
        def select_method(self, key: tp.Key, _target_name=target_name, **kwargs) -> tp.Any:
            if _target_name.endswith("_kwargs"):
                selector, name_arg = self.select_key_kwargs, "kwargs_name"
            else:
                selector, name_arg = self.select_key_from_dict, "dct_name"
            return selector(key, getattr(self, _target_name), **{name_arg: _target_name}, **kwargs)

        _attach(
            select_method,
            "select_" + target_name,
            f"""Select a feature or symbol from `Data.{target_name}`.""",
        )

    for target_name in cls._updatable_attrs:

        # The target name is bound as a default to avoid late binding of the loop variable
        def update_method(self: DataT, _target_name=target_name, check_dict_type: bool = True, **kwargs) -> DataT:
            from vectorbtpro.data.base import key_dict

            updated = copy_dict(getattr(self, _target_name))
            dict_type = type(updated)
            # Make sure that each key has its own dictionary to write to
            for key in self.get_keys(dict_type):
                if key not in updated:
                    updated[key] = dict()
            for arg_name, arg_value in kwargs.items():
                if check_dict_type:
                    self.check_dict_type(arg_value, arg_name, dict_type=dict_type)
                if type(arg_value) is key_dict or isinstance(arg_value, dict_type):
                    # Value is provided per key
                    for key, key_value in arg_value.items():
                        updated[key][arg_name] = key_value
                else:
                    # Value is the same for all keys
                    for key in updated:
                        updated[key][arg_name] = arg_value
            return self.replace(**{_target_name: updated})

        _attach(
            update_method,
            "update_" + target_name,
            f"""Update `Data.{target_name}`. Returns a new instance.""",
        )

    return cls
@register_jitted(cache=True)
def generate_gbm_data_1d_nb(
    n_rows: int,
    start_value: float = 100.0,
    mean: float = 0.0,
    std: float = 0.01,
    dt: float = 1.0,
) -> tp.Array1d:
    """Generate data using Geometric Brownian Motion (GBM).

    Returns a 1-dim array with `n_rows` values. Each value is the previous value
    (`start_value` for the first row) multiplied by
    `exp((mean - 0.5 * std**2) * dt + std * sqrt(dt) * z)`, where `z` is drawn from
    the standard normal distribution. Note that the first value already includes
    a random step and thus generally differs from `start_value`."""
    out = np.empty(n_rows, dtype=float_)
    # Both terms do not depend on the row and can be computed once
    drift = (mean - 0.5 * std**2) * dt
    diffusion = std * np.sqrt(dt)
    for i in range(n_rows):
        if i == 0:
            prev_value = start_value
        else:
            prev_value = out[i - 1]
        rand = np.random.standard_normal()
        out[i] = prev_value * np.exp(drift + diffusion * rand)
    return out
class DataSaver(DataUpdater):
    """Base class for scheduling data saves.

    Subclasses `vectorbtpro.data.updater.DataUpdater`.

    Args:
        data (Data): Data instance.
        save_kwargs (dict): Default keyword arguments for `DataSaver.init_save_data` and `DataSaver.save_data`.
        init_save_kwargs (dict): Default keyword arguments overriding `save_kwargs` for `DataSaver.init_save_data`.
        **kwargs: Keyword arguments passed to the constructor of `DataUpdater`.
    """

    def __init__(
        self,
        data: Data,
        save_kwargs: tp.KwargsLike = None,
        init_save_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        saver_config = dict(save_kwargs=save_kwargs, init_save_kwargs=init_save_kwargs)
        DataUpdater.__init__(self, data=data, **saver_config, **kwargs)

        self._save_kwargs = save_kwargs
        self._init_save_kwargs = init_save_kwargs

    @property
    def save_kwargs(self) -> tp.KwargsLike:
        """Keyword arguments passed to `DataSaver.save_data`."""
        return self._save_kwargs

    @property
    def init_save_kwargs(self) -> tp.KwargsLike:
        """Keyword arguments passed to `DataSaver.init_save_data`."""
        return self._init_save_kwargs

    def init_save_data(self, **kwargs) -> None:
        """Save initial data.

        This is an abstract method - override it to define custom logic."""
        raise NotImplementedError

    def save_data(self, **kwargs) -> None:
        """Save data.

        This is an abstract method - override it to define custom logic."""
        raise NotImplementedError

    def update(self, save_kwargs: tp.KwargsLike = None, **kwargs) -> None:
        """Update and save data using `DataSaver.save_data`.

        Data is updated with `concat=False` and the keyword arguments remaining after
        `save_kwargs` has been taken out; `save_kwargs` is then passed to `DataSaver.save_data`.

        Override to do pre- and postprocessing.

        To stop this method from running again, raise `vectorbtpro.utils.schedule_.CancelledError`."""
        # In case the method was called by the user
        merged_kwargs = merge_dicts(
            dict(save_kwargs=self.save_kwargs),
            self.update_kwargs,
            {"save_kwargs": save_kwargs, **kwargs},
        )
        resolved_save_kwargs = merged_kwargs.pop("save_kwargs")

        self._data = self.data.update(concat=False, **merged_kwargs)
        self.update_config(data=self.data)
        self.save_data(**({} if resolved_save_kwargs is None else resolved_save_kwargs))

    def update_every(
        self,
        *args,
        save_kwargs: tp.KwargsLike = None,
        init_save: bool = False,
        init_save_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        """Overrides `vectorbtpro.data.updater.DataUpdater.update_every` to save initial data prior to updating.

        Initial data is saved with `DataSaver.init_save_data` only if `init_save` is True."""
        if init_save:
            # Later dictionaries take precedence over earlier ones
            resolved_init_kwargs = merge_dicts(
                self.save_kwargs,
                save_kwargs,
                self.init_save_kwargs,
                init_save_kwargs,
            )
            self.init_save_data(**resolved_init_kwargs)
        DataUpdater.update_every(self, *args, save_kwargs=save_kwargs, **kwargs)
class HDFDataSaver(DataSaver):
    """Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_hdf`."""

    def init_save_data(self, **to_hdf_kwargs) -> None:
        """Save initial data.

        Keyword arguments are merged over `DataSaver.save_kwargs` and `DataSaver.init_save_kwargs`."""
        # Instance-level defaults are merged here too, in case the method was called by the user
        resolved_kwargs = merge_dicts(self.save_kwargs, self.init_save_kwargs, to_hdf_kwargs)
        self.data.to_hdf(**resolved_kwargs)
        new_index = self.data.wrapper.index
        logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")

    def save_data(self, **to_hdf_kwargs) -> None:
        """Save data.

        By default, appends new data in a table format (`mode="a"` and `append=True`).
        Keyword arguments are merged over these defaults and `DataSaver.save_kwargs`."""
        # Instance-level defaults are merged here too, in case the method was called by the user
        append_defaults = {"mode": "a", "append": True}
        resolved_kwargs = merge_dicts(append_defaults, self.save_kwargs, to_hdf_kwargs)
        self.data.to_hdf(**resolved_kwargs)
        new_index = self.data.wrapper.index
        logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
class DuckDBDataSaver(DataSaver):
    """Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_duckdb`."""

    def init_save_data(self, **to_duckdb_kwargs) -> None:
        """Save initial data.

        Keyword arguments are merged over `DataSaver.save_kwargs` and `DataSaver.init_save_kwargs`."""
        # Instance-level defaults are merged here too, in case the method was called by the user
        resolved_kwargs = merge_dicts(self.save_kwargs, self.init_save_kwargs, to_duckdb_kwargs)
        self.data.to_duckdb(**resolved_kwargs)
        new_index = self.data.wrapper.index
        logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")

    def save_data(self, **to_duckdb_kwargs) -> None:
        """Save data.

        By default, uses `if_exists="append"`. Keyword arguments are merged over this
        default and `DataSaver.save_kwargs`."""
        # Instance-level defaults are merged here too, in case the method was called by the user
        append_defaults = {"if_exists": "append"}
        resolved_kwargs = merge_dicts(append_defaults, self.save_kwargs, to_duckdb_kwargs)
        self.data.to_duckdb(**resolved_kwargs)
        new_index = self.data.wrapper.index
        logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
class DataUpdater(Configured):
    """Base class for scheduling data updates.

    Args:
        data (Data): Data instance.
        schedule_manager (ScheduleManager): Schedule manager instance.

            If None, a new `vectorbtpro.utils.schedule_.ScheduleManager` is created.
        update_kwargs (dict): Default keyword arguments for `DataUpdater.update`.
        **kwargs: Keyword arguments passed to the constructor of `Configured`.
    """

    def __init__(
        self,
        data: Data,
        schedule_manager: tp.Optional[ScheduleManager] = None,
        update_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        manager = ScheduleManager() if schedule_manager is None else schedule_manager

        Configured.__init__(
            self,
            data=data,
            schedule_manager=manager,
            update_kwargs=update_kwargs,
            **kwargs,
        )

        self._data = data
        self._schedule_manager = manager
        self._update_kwargs = update_kwargs

    @property
    def data(self) -> Data:
        """Data instance.

        See `vectorbtpro.data.base.Data`."""
        return self._data

    @property
    def schedule_manager(self) -> ScheduleManager:
        """Schedule manager instance.

        See `vectorbtpro.utils.schedule_.ScheduleManager`."""
        return self._schedule_manager

    @property
    def update_kwargs(self) -> tp.KwargsLike:
        """Keyword arguments passed to `DataUpdater.update`."""
        return self._update_kwargs

    def update(self, **kwargs) -> None:
        """Method that updates data.

        Keyword arguments are merged over `DataUpdater.update_kwargs` and passed to the
        `update` method of the data instance. The updated instance replaces `DataUpdater.data`.

        Override to do pre- and postprocessing.

        To stop this method from running again, raise `vectorbtpro.utils.schedule_.CancelledError`."""
        # Defaults are merged here too, in case the method was called by the user
        merged_kwargs = merge_dicts(self.update_kwargs, kwargs)
        self._data = self.data.update(**merged_kwargs)
        self.update_config(data=self.data)
        new_index = self.data.wrapper.index
        logger.info(f"New data has {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")

    def update_every(
        self,
        *args,
        to: int = None,
        tags: tp.Optional[tp.Iterable[tp.Hashable]] = None,
        in_background: bool = False,
        replace: bool = True,
        start: bool = True,
        start_kwargs: tp.KwargsLike = None,
        **update_kwargs,
    ) -> None:
        """Schedule `DataUpdater.update` as a job.

        For `*args`, `to` and `tags`, see `vectorbtpro.utils.schedule_.ScheduleManager.every`.

        If `in_background` is set to True, starts in the background as an `asyncio` task.
        The task can be stopped with `vectorbtpro.utils.schedule_.ScheduleManager.stop`.

        If `replace` is True, will delete scheduled jobs with the same tags, or all jobs if tags are omitted.

        If `start` is False, will add the job to the scheduler without starting.
        Otherwise, `start_kwargs` are passed to the method that starts the scheduler.

        `**update_kwargs` are merged over `DataUpdater.update_kwargs` and passed to `DataUpdater.update`."""
        manager = self.schedule_manager
        if replace:
            manager.clear_jobs(tags)
        job_kwargs = merge_dicts(self.update_kwargs, update_kwargs)
        manager.every(*args, to=to, tags=tags).do(self.update, **job_kwargs)
        if not start:
            return
        launch_kwargs = {} if start_kwargs is None else start_kwargs
        if in_background:
            manager.start_in_background(**launch_kwargs)
        else:
            manager.start(**launch_kwargs)
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Numba-compiled functions for generic data. Provides an arsenal of Numba-compiled functions that are used by accessors and in many other parts of a backtesting pipeline, such as technical indicators. These only accept NumPy arrays and other Numba-compatible types. !!! note vectorbt treats matrices as first-class citizens and expects input arrays to be 2-dim, unless function has suffix `_1d` or is meant to be input to another function. Data is processed along index (axis 0). Rolling functions with `minp=None` have `min_periods` set to the window size. All functions passed as argument must be Numba-compiled. Records must retain the order they were created in. !!! warning Make sure to use `parallel=True` only if your columns are independent. """ from vectorbtpro.generic.nb.apply_reduce import * from vectorbtpro.generic.nb.base import * from vectorbtpro.generic.nb.iter_ import * from vectorbtpro.generic.nb.patterns import * from vectorbtpro.generic.nb.records import * from vectorbtpro.generic.nb.rolling import * from vectorbtpro.generic.nb.sim_range import * __all__ = [] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
@register_jitted
def map_1d_nb(arr: tp.Array1d, map_func_nb: tp.MapFunc, *args) -> tp.Array1d:
    """Map each element of a 1-dim array using `map_func_nb`.

    `map_func_nb` must accept the element and `*args`. Must return a single value.

    The output data type is taken from the result of the first element."""
    first_out = map_func_nb(arr[0], *args)
    mapped = np.empty_like(arr, dtype=np.asarray(first_out).dtype)
    mapped[0] = first_out
    n_elems = arr.shape[0]
    for k in range(1, n_elems):
        mapped[k] = map_func_nb(arr[k], *args)
    return mapped


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        map_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def map_nb(arr: tp.Array2d, map_func_nb: tp.MapFunc, *args) -> tp.Array2d:
    """2-dim version of `map_1d_nb`.

    The first column is mapped up front to determine the output data type."""
    first_col_out = map_1d_nb(arr[:, 0], map_func_nb, *args)
    mapped = np.empty_like(arr, dtype=first_col_out.dtype)
    mapped[:, 0] = first_col_out
    n_cols = mapped.shape[1]
    for col in prange(1, n_cols):
        mapped[:, col] = map_1d_nb(arr[:, col], map_func_nb, *args)
    return mapped


@register_jitted
def map_1d_meta_nb(n: int, col: int, map_func_nb: tp.MapMetaFunc, *args) -> tp.Array1d:
    """Meta version of `map_1d_nb`.

    `map_func_nb` must accept the row index, the column index, and `*args`.
    Must return a single value.

    Produces `n` values; the output data type is taken from the result at row 0."""
    first_out = map_func_nb(0, col, *args)
    mapped = np.empty(n, dtype=np.asarray(first_out).dtype)
    mapped[0] = first_out
    for row in range(1, n):
        mapped[row] = map_func_nb(row, col, *args)
    return mapped
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        map_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def map_meta_nb(target_shape: tp.Shape, map_func_nb: tp.MapMetaFunc, *args) -> tp.Array2d:
    """2-dim version of `map_1d_meta_nb`.

    The first column is mapped up front to determine the output data type."""
    n_rows = target_shape[0]
    first_col_out = map_1d_meta_nb(n_rows, 0, map_func_nb, *args)
    mapped = np.empty(target_shape, dtype=first_col_out.dtype)
    mapped[:, 0] = first_col_out
    n_cols = mapped.shape[1]
    for col in prange(1, n_cols):
        mapped[:, col] = map_1d_meta_nb(n_rows, col, map_func_nb, *args)
    return mapped


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        apply_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def apply_nb(arr: tp.Array2d, apply_func_nb: tp.ApplyFunc, *args) -> tp.Array2d:
    """Apply function on each column of an object.

    `apply_func_nb` must accept the array and `*args`.
    Must return a single value or an array of shape `arr.shape[0]`.

    The output data type is taken from the result of the first column."""
    first_col_out = apply_func_nb(arr[:, 0], *args)
    applied = np.empty_like(arr, dtype=np.asarray(first_col_out).dtype)
    applied[:, 0] = first_col_out
    n_cols = arr.shape[1]
    for col in prange(1, n_cols):
        applied[:, col] = apply_func_nb(arr[:, col], *args)
    return applied
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        apply_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def apply_meta_nb(target_shape: tp.Shape, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array2d:
    """Meta version of `apply_nb` that prepends the column index to the arguments of `apply_func_nb`.

    The output data type is taken from the result of column 0."""
    first_col_out = apply_func_nb(0, *args)
    applied = np.empty(target_shape, dtype=np.asarray(first_col_out).dtype)
    applied[:, 0] = first_col_out
    n_cols = target_shape[1]
    for col in prange(1, n_cols):
        applied[:, col] = apply_func_nb(col, *args)
    return applied


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=0),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=0),
        apply_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def row_apply_nb(arr: tp.Array2d, apply_func_nb: tp.ApplyFunc, *args) -> tp.Array2d:
    """`apply_nb` but applied on rows rather than columns.

    The output data type is taken from the result of the first row."""
    first_row_out = apply_func_nb(arr[0, :], *args)
    applied = np.empty_like(arr, dtype=np.asarray(first_row_out).dtype)
    applied[0, :] = first_row_out
    n_rows = arr.shape[0]
    for row in prange(1, n_rows):
        applied[row, :] = apply_func_nb(arr[row, :], *args)
    return applied


@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=0),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=0),
        apply_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def row_apply_meta_nb(target_shape: tp.Shape, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array2d:
    """Meta version of `row_apply_nb` that prepends the row index to the arguments of `apply_func_nb`.

    The output data type is taken from the result of row 0."""
    first_row_out = apply_func_nb(0, *args)
    applied = np.empty(target_shape, dtype=np.asarray(first_row_out).dtype)
    applied[0, :] = first_row_out
    n_rows = target_shape[0]
    for row in prange(1, n_rows):
        applied[row, :] = apply_func_nb(row, *args)
    return applied
@register_jitted
def rolling_reduce_1d_nb(
    arr: tp.Array1d,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Provide rolling window calculations.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    If `minp` is None, it is set to `window`. Positions whose window holds fewer
    than `minp` non-NaN values are set to NaN without calling `reduce_func_nb`."""
    if minp is None:
        minp = window
    n_rows = arr.shape[0]
    out = np.empty_like(arr, dtype=float_)
    # Running NaN count at each position, used to count NaNs inside the window
    cum_nans = np.empty(n_rows, dtype=int_)
    n_nans = 0
    for i in range(n_rows):
        if np.isnan(arr[i]):
            n_nans = n_nans + 1
        cum_nans[i] = n_nans
        if i >= window:
            n_valid = window - (n_nans - cum_nans[i - window])
        else:
            n_valid = i + 1 - n_nans
        if n_valid >= minp:
            start = max(0, i + 1 - window)
            out[i] = reduce_func_nb(arr[start : i + 1], *args)
        else:
            out[i] = np.nan
    return out


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        window=None,
        minp=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_nb(
    arr: tp.Array2d,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `rolling_reduce_1d_nb`."""
    result = np.empty_like(arr, dtype=float_)
    n_cols = arr.shape[1]
    for col in prange(n_cols):
        result[:, col] = rolling_reduce_1d_nb(arr[:, col], window, minp, reduce_func_nb, *args)
    return result


@register_jitted
def rolling_reduce_two_1d_nb(
    arr1: tp.Array1d,
    arr2: tp.Array1d,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Provide rolling window calculations for two arrays.

    `reduce_func_nb` must accept two arrays and `*args`. Must return a single value.

    A position counts as NaN if either array holds NaN there. If `minp` is None,
    it is set to `window`. Positions whose window holds fewer than `minp` valid
    values are set to NaN without calling `reduce_func_nb`."""
    if minp is None:
        minp = window
    n_rows = arr1.shape[0]
    out = np.empty_like(arr1, dtype=float_)
    # Running NaN count at each position, used to count NaNs inside the window
    cum_nans = np.empty(n_rows, dtype=int_)
    n_nans = 0
    for i in range(n_rows):
        if np.isnan(arr1[i]) or np.isnan(arr2[i]):
            n_nans = n_nans + 1
        cum_nans[i] = n_nans
        if i >= window:
            n_valid = window - (n_nans - cum_nans[i - window])
        else:
            n_valid = i + 1 - n_nans
        if n_valid >= minp:
            start = max(0, i + 1 - window)
            stop = i + 1
            out[i] = reduce_func_nb(arr1[start:stop], arr2[start:stop], *args)
        else:
            out[i] = np.nan
    return out
# The chunking specification must refer to the actual argument names (`arr1`, `arr2`):
# the function has no argument called `arr`, and both arrays need to be sliced by column
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr1", axis=1),
    arg_take_spec=dict(
        arr1=ch.ArraySlicer(axis=1),
        arr2=ch.ArraySlicer(axis=1),
        window=None,
        minp=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_two_nb(
    arr1: tp.Array2d,
    arr2: tp.Array2d,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `rolling_reduce_two_1d_nb`.

    Columns of `arr1` and `arr2` are paired by position; the output has the shape of `arr1`."""
    out = np.empty_like(arr1, dtype=float_)
    for col in prange(arr1.shape[1]):
        out[:, col] = rolling_reduce_two_1d_nb(arr1[:, col], arr2[:, col], window, minp, reduce_func_nb, *args)
    return out


@register_jitted
def rolling_reduce_1d_meta_nb(
    n: int,
    col: int,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array1d:
    """Meta version of `rolling_reduce_1d_nb`.

    `reduce_func_nb` must accept the start row index, the end row index, the column, and `*args`.
    Must return a single value.

    If `minp` is None, it is set to `window`. Since no array is passed, the number of
    valid values in a window is its length; windows shorter than `minp` yield NaN."""
    if minp is None:
        minp = window
    out = np.empty(n, dtype=float_)
    for i in range(n):
        valid_cnt = min(i + 1, window)
        if valid_cnt < minp:
            out[i] = np.nan
        else:
            from_i = max(0, i + 1 - window)
            to_i = i + 1
            out[i] = reduce_func_nb(from_i, to_i, col, *args)
    return out
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        window=None,
        minp=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_meta_nb(
    target_shape: tp.Shape,
    window: int,
    minp: tp.Optional[int],
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `rolling_reduce_1d_meta_nb`."""
    n_rows = target_shape[0]
    n_cols = target_shape[1]
    result = np.empty(target_shape, dtype=float_)
    for col in prange(n_cols):
        result[:, col] = rolling_reduce_1d_meta_nb(n_rows, col, window, minp, reduce_func_nb, *args)
    return result


@register_jitted
def rolling_freq_reduce_1d_nb(
    index: tp.Array1d,
    arr: tp.Array1d,
    freq: np.timedelta64,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Provide rolling, frequency-based window calculations.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    The window at row `i` starts at the first row whose index value is greater than
    `index[i] - freq` and ends at row `i` (inclusive)."""
    out = np.empty_like(arr, dtype=float_)
    n_index = index.shape[0]
    window_start = 0
    for i in range(arr.shape[0]):
        cutoff = index[i] - freq
        if index[window_start] <= cutoff:
            # The start only moves forward: search from the next row onwards
            for j in range(window_start + 1, n_index):
                if index[j] > cutoff:
                    window_start = j
                    break
        out[i] = reduce_func_nb(arr[window_start : i + 1], *args)
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        index=None,
        arr=ch.ArraySlicer(axis=1),
        freq=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_freq_reduce_nb(
    index: tp.Array1d,
    arr: tp.Array2d,
    freq: np.timedelta64,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `rolling_freq_reduce_1d_nb`."""
    result = np.empty_like(arr, dtype=float_)
    n_cols = arr.shape[1]
    for col in prange(n_cols):
        result[:, col] = rolling_freq_reduce_1d_nb(index, arr[:, col], freq, reduce_func_nb, *args)
    return result


@register_jitted
def rolling_freq_reduce_1d_meta_nb(
    col: int,
    index: tp.Array1d,
    freq: np.timedelta64,
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array1d:
    """Meta version of `rolling_freq_reduce_1d_nb`.

    `reduce_func_nb` must accept the start row index, the end row index, the column, and `*args`.
    Must return a single value.

    The start row at row `i` is the first row whose index value is greater than
    `index[i] - freq`; the end row is `i + 1` (exclusive)."""
    n_index = index.shape[0]
    out = np.empty(n_index, dtype=float_)
    window_start = 0
    for i in range(n_index):
        cutoff = index[i] - freq
        if index[window_start] <= cutoff:
            # The start only moves forward: search from the next row onwards
            for j in range(window_start + 1, n_index):
                if index[j] > cutoff:
                    window_start = j
                    break
        out[i] = reduce_func_nb(window_start, i + 1, col, *args)
    return out
@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        index=None,
        freq=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_freq_reduce_meta_nb(
    n_cols: int,
    index: tp.Array1d,
    freq: np.timedelta64,
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `rolling_freq_reduce_1d_meta_nb`."""
    n_rows = index.shape[0]
    result = np.empty((n_rows, n_cols), dtype=float_)
    for col in prange(n_cols):
        result[:, col] = rolling_freq_reduce_1d_meta_nb(col, index, freq, reduce_func_nb, *args)
    return result


@register_jitted
def groupby_reduce_1d_nb(arr: tp.Array1d, group_map: tp.GroupMap, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array1d:
    """Provide group-by reduce calculations.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    `group_map` is a tuple of group indices and group lengths. The output has one
    value per group; its data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = group_lens.shape[0]
    first_idxs = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(arr[first_idxs], *args)
    out = np.empty(n_groups, dtype=np.asarray(first_out).dtype)
    out[0] = first_out
    for group in range(1, n_groups):
        idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[group] = reduce_func_nb(arr[idxs], *args)
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        group_map=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def groupby_reduce_nb(arr: tp.Array2d, group_map: tp.GroupMap, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array2d:
    """2-dim version of `groupby_reduce_1d_nb`.

    The first column is reduced up front to determine the output shape and data type."""
    first_col_out = groupby_reduce_1d_nb(arr[:, 0], group_map, reduce_func_nb, *args)
    n_cols = arr.shape[1]
    result = np.empty((first_col_out.shape[0], n_cols), dtype=first_col_out.dtype)
    result[:, 0] = first_col_out
    for col in prange(1, n_cols):
        result[:, col] = groupby_reduce_1d_nb(arr[:, col], group_map, reduce_func_nb, *args)
    return result


@register_jitted
def groupby_reduce_1d_meta_nb(
    col: int,
    group_map: tp.GroupMap,
    reduce_func_nb: tp.GroupByReduceMetaFunc,
    *args,
) -> tp.Array1d:
    """Meta version of `groupby_reduce_1d_nb`.

    `reduce_func_nb` must accept the array of indices in the group, the group index,
    the column index, and `*args`. Must return a single value.

    The output data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = group_lens.shape[0]
    first_idxs = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(first_idxs, 0, col, *args)
    out = np.empty(n_groups, dtype=np.asarray(first_out).dtype)
    out[0] = first_out
    for group in range(1, n_groups):
        idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[group] = reduce_func_nb(idxs, group, col, *args)
    return out
@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        group_map=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def groupby_reduce_meta_nb(
    n_cols: int,
    group_map: tp.GroupMap,
    reduce_func_nb: tp.GroupByReduceMetaFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `groupby_reduce_1d_meta_nb`.

    Column 0 is reduced up front to determine the output shape and data type."""
    first_col_out = groupby_reduce_1d_meta_nb(0, group_map, reduce_func_nb, *args)
    result = np.empty((first_col_out.shape[0], n_cols), dtype=first_col_out.dtype)
    result[:, 0] = first_col_out
    for col in prange(1, n_cols):
        result[:, col] = groupby_reduce_1d_meta_nb(col, group_map, reduce_func_nb, *args)
    return result


@register_jitted(tags={"can_parallel"})
def groupby_transform_nb(
    arr: tp.Array2d,
    group_map: tp.GroupMap,
    transform_func_nb: tp.GroupByTransformFunc,
    *args,
) -> tp.Array2d:
    """Provide group-by transform calculations.

    `transform_func_nb` must accept the 2-dim array of the group and `*args`.
    Must return a scalar or an array that broadcasts against the group array's shape.

    The output has the shape of `arr`; its data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = group_lens.shape[0]
    first_idxs = group_idxs[group_starts[0] : group_ends[0]]
    first_out = transform_func_nb(arr[first_idxs], *args)
    out = np.empty(arr.shape, dtype=np.asarray(first_out).dtype)
    out[first_idxs] = first_out
    for group in prange(1, n_groups):
        idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[idxs] = transform_func_nb(arr[idxs], *args)
    return out
@register_jitted(tags={"can_parallel"})
def groupby_transform_meta_nb(
    target_shape: tp.Shape,
    group_map: tp.GroupMap,
    transform_func_nb: tp.GroupByTransformMetaFunc,
    *args,
) -> tp.Array2d:
    """Meta version of `groupby_transform_nb`.

    `transform_func_nb` must accept the array of indices in the group, the group index, and `*args`.
    Must return a scalar or an array that broadcasts against the group's shape.

    The output has the shape `target_shape`; its data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = group_lens.shape[0]
    first_idxs = group_idxs[group_starts[0] : group_ends[0]]
    first_out = transform_func_nb(first_idxs, 0, *args)
    out = np.empty(target_shape, dtype=np.asarray(first_out).dtype)
    out[first_idxs] = first_out
    for group in prange(1, n_groups):
        idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[idxs] = transform_func_nb(idxs, group, *args)
    return out


@register_jitted
def reduce_index_ranges_1d_nb(
    arr: tp.Array1d,
    range_starts: tp.Array1d,
    range_ends: tp.Array1d,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Reduce each index range.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    A range whose start or end is -1 yields NaN without calling `reduce_func_nb`."""
    n_ranges = range_starts.shape[0]
    out = np.empty(n_ranges, dtype=float_)
    for k in range(len(range_starts)):
        start = range_starts[k]
        stop = range_ends[k]
        if start != -1 and stop != -1:
            out[k] = reduce_func_nb(arr[start:stop], *args)
        else:
            out[k] = np.nan
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        range_starts=None,
        range_ends=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_index_ranges_nb(
    arr: tp.Array2d,
    range_starts: tp.Array1d,
    range_ends: tp.Array1d,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `reduce_index_ranges_1d_nb`."""
    n_ranges = range_starts.shape[0]
    n_cols = arr.shape[1]
    result = np.empty((n_ranges, n_cols), dtype=float_)
    for col in prange(n_cols):
        result[:, col] = reduce_index_ranges_1d_nb(arr[:, col], range_starts, range_ends, reduce_func_nb, *args)
    return result


@register_jitted
def reduce_index_ranges_1d_meta_nb(
    col: int,
    range_starts: tp.Array1d,
    range_ends: tp.Array1d,
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array1d:
    """Meta version of `reduce_index_ranges_1d_nb`.

    `reduce_func_nb` must accept the start row index, the end row index, the column, and `*args`.
    Must return a single value.

    A range whose start or end is -1 yields NaN without calling `reduce_func_nb`."""
    n_ranges = range_starts.shape[0]
    out = np.empty(n_ranges, dtype=float_)
    for k in range(len(range_starts)):
        start = range_starts[k]
        stop = range_ends[k]
        if start != -1 and stop != -1:
            out[k] = reduce_func_nb(start, stop, col, *args)
        else:
            out[k] = np.nan
    return out
@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        range_starts=None,
        range_ends=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_index_ranges_meta_nb(
    n_cols: int,
    range_starts: tp.Array1d,
    range_ends: tp.Array1d,
    reduce_func_nb: tp.RangeReduceMetaFunc,
    *args,
) -> tp.Array2d:
    """2-dim version of `reduce_index_ranges_1d_meta_nb`."""
    n_ranges = range_starts.shape[0]
    result = np.empty((n_ranges, n_cols), dtype=float_)
    for col in prange(n_cols):
        result[:, col] = reduce_index_ranges_1d_meta_nb(col, range_starts, range_ends, reduce_func_nb, *args)
    return result


@register_jitted
def apply_and_reduce_1d_nb(
    arr: tp.Array1d,
    apply_func_nb: tp.ApplyFunc,
    apply_args: tuple,
    reduce_func_nb: tp.ReduceFunc,
    reduce_args: tuple,
) -> tp.Scalar:
    """Apply `apply_func_nb` and reduce into a single value using `reduce_func_nb`.

    `apply_func_nb` must accept the array and `*apply_args`.
    Must return an array.

    `reduce_func_nb` must accept the array of results from `apply_func_nb` and `*reduce_args`.
    Must return a single value."""
    return reduce_func_nb(apply_func_nb(arr, *apply_args), *reduce_args)
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        apply_func_nb=None,
        apply_args=ch.ArgsTaker(),
        reduce_func_nb=None,
        reduce_args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_and_reduce_nb(
    arr: tp.Array2d,
    apply_func_nb: tp.ApplyFunc,
    apply_args: tuple,
    reduce_func_nb: tp.ReduceFunc,
    reduce_args: tuple,
) -> tp.Array1d:
    """2-dim version of `apply_and_reduce_1d_nb`.

    The output has one value per column; its data type is taken from the result of the first column."""
    first_out = apply_and_reduce_1d_nb(arr[:, 0], apply_func_nb, apply_args, reduce_func_nb, reduce_args)
    n_cols = arr.shape[1]
    result = np.empty(n_cols, dtype=np.asarray(first_out).dtype)
    result[0] = first_out
    for col in prange(1, n_cols):
        result[col] = apply_and_reduce_1d_nb(arr[:, col], apply_func_nb, apply_args, reduce_func_nb, reduce_args)
    return result


@register_jitted
def apply_and_reduce_1d_meta_nb(
    col: int,
    apply_func_nb: tp.ApplyMetaFunc,
    apply_args: tuple,
    reduce_func_nb: tp.ReduceMetaFunc,
    reduce_args: tuple,
) -> tp.Scalar:
    """Meta version of `apply_and_reduce_1d_nb`.

    `apply_func_nb` must accept the column index and `*apply_args`.
    Must return an array.

    `reduce_func_nb` must accept the column index, the array of results from `apply_func_nb`,
    and `*reduce_args`. Must return a single value."""
    return reduce_func_nb(col, apply_func_nb(col, *apply_args), *reduce_args)
@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        apply_func_nb=None,
        apply_args=ch.ArgsTaker(),
        reduce_func_nb=None,
        reduce_args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_and_reduce_meta_nb(
    n_cols: int,
    apply_func_nb: tp.ApplyMetaFunc,
    apply_args: tuple,
    reduce_func_nb: tp.ReduceMetaFunc,
    reduce_args: tuple,
) -> tp.Array1d:
    """2-dim version of `apply_and_reduce_1d_meta_nb`.

    The output has one value per column; its data type is taken from the result of column 0."""
    first_out = apply_and_reduce_1d_meta_nb(0, apply_func_nb, apply_args, reduce_func_nb, reduce_args)
    result = np.empty(n_cols, dtype=np.asarray(first_out).dtype)
    result[0] = first_out
    for col in prange(1, n_cols):
        result[col] = apply_and_reduce_1d_meta_nb(col, apply_func_nb, apply_args, reduce_func_nb, reduce_args)
    return result


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_nb(arr: tp.Array2d, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array1d:
    """Reduce each column into a single value using `reduce_func_nb`.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    The output data type is taken from the result of the first column."""
    first_out = reduce_func_nb(arr[:, 0], *args)
    n_cols = arr.shape[1]
    result = np.empty(n_cols, dtype=np.asarray(first_out).dtype)
    result[0] = first_out
    for col in prange(1, n_cols):
        result[col] = reduce_func_nb(arr[:, col], *args)
    return result


@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_meta_nb(n_cols: int, reduce_func_nb: tp.ReduceMetaFunc, *args) -> tp.Array1d:
    """Meta version of `reduce_nb`.

    `reduce_func_nb` must accept the column index and `*args`. Must return a single value.

    The output data type is taken from the result of column 0."""
    first_out = reduce_func_nb(0, *args)
    result = np.empty(n_cols, dtype=np.asarray(first_out).dtype)
    result[0] = first_out
    for col in prange(1, n_cols):
        result[col] = reduce_func_nb(col, *args)
    return result
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_to_array_nb(arr: tp.Array2d, reduce_func_nb: tp.ReduceToArrayFunc, *args) -> tp.Array2d:
    """Same as `reduce_nb` but `reduce_func_nb` must return an array.

    The number of output rows and the data type are taken from the result of the first column."""
    first_out = reduce_func_nb(arr[:, 0], *args)
    n_cols = arr.shape[1]
    result = np.empty((first_out.shape[0], n_cols), dtype=first_out.dtype)
    result[:, 0] = first_out
    for col in prange(1, n_cols):
        result[:, col] = reduce_func_nb(arr[:, col], *args)
    return result


@register_chunkable(
    size=ch.ArgSizer(arg_query="n_cols"),
    arg_take_spec=dict(
        n_cols=ch.CountAdapter(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_to_array_meta_nb(n_cols: int, reduce_func_nb: tp.ReduceToArrayMetaFunc, *args) -> tp.Array2d:
    """Same as `reduce_meta_nb` but `reduce_func_nb` must return an array.

    The number of output rows and the data type are taken from the result of column 0."""
    first_out = reduce_func_nb(0, *args)
    result = np.empty((first_out.shape[0], n_cols), dtype=first_out.dtype)
    result[:, 0] = first_out
    for col in prange(1, n_cols):
        result[:, col] = reduce_func_nb(col, *args)
    return result


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_nb(
    arr: tp.Array2d,
    group_map: tp.GroupMap,
    reduce_func_nb: tp.ReduceGroupedFunc,
    *args,
) -> tp.Array1d:
    """Reduce each group of columns into a single value using `reduce_func_nb`.

    `reduce_func_nb` must accept the 2-dim array and `*args`. Must return a single value.

    The output has one value per group; its data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = len(group_lens)
    first_cols = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(arr[:, first_cols], *args)
    out = np.empty(n_groups, dtype=np.asarray(first_out).dtype)
    out[0] = first_out
    for group in prange(1, n_groups):
        col_idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[group] = reduce_func_nb(arr[:, col_idxs], *args)
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        group_map=base_ch.GroupMapSlicer(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_meta_nb(group_map: tp.GroupMap, reduce_func_nb: tp.ReduceGroupedMetaFunc, *args) -> tp.Array1d:
    """Meta version of `reduce_grouped_nb`.

    `reduce_func_nb` must accept the column indices of the group, the group index, and `*args`.
    Must return a single value.

    The output has one value per group; its data type is taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = len(group_lens)
    first_cols = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(first_cols, 0, *args)
    out = np.empty(n_groups, dtype=np.asarray(first_out).dtype)
    out[0] = first_out
    for group in prange(1, n_groups):
        col_idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[group] = reduce_func_nb(col_idxs, group, *args)
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def flatten_forder_nb(arr: tp.Array2d) -> tp.Array1d:
    """Flatten the array in F order.

    Columns are laid out one after another in the output."""
    n_rows = arr.shape[0]
    n_cols = arr.shape[1]
    flat = np.empty(n_rows * n_cols, dtype=arr.dtype)
    for col in prange(n_cols):
        offset = col * n_rows
        flat[offset : offset + n_rows] = arr[:, col]
    return flat


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
        in_c_order=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_flat_grouped_nb(
    arr: tp.Array2d,
    group_map: tp.GroupMap,
    in_c_order: bool,
    reduce_func_nb: tp.ReduceToArrayFunc,
    *args,
) -> tp.Array1d:
    """Same as `reduce_grouped_nb` but passes flattened array.

    If `in_c_order` is True, the group array is flattened with `flatten()`,
    otherwise with `flatten_forder_nb`."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = len(group_lens)
    first_cols = group_idxs[group_starts[0] : group_ends[0]]
    if not in_c_order:
        first_out = reduce_func_nb(flatten_forder_nb(arr[:, first_cols]), *args)
    else:
        first_out = reduce_func_nb(arr[:, first_cols].flatten(), *args)
    out = np.empty(n_groups, dtype=np.asarray(first_out).dtype)
    out[0] = first_out
    for group in prange(1, n_groups):
        col_idxs = group_idxs[group_starts[group] : group_ends[group]]
        if not in_c_order:
            out[group] = reduce_func_nb(flatten_forder_nb(arr[:, col_idxs]), *args)
        else:
            out[group] = reduce_func_nb(arr[:, col_idxs].flatten(), *args)
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_to_array_nb(
    arr: tp.Array2d,
    group_map: tp.GroupMap,
    reduce_func_nb: tp.ReduceGroupedToArrayFunc,
    *args,
) -> tp.Array2d:
    """Same as `reduce_grouped_nb` but `reduce_func_nb` must return an array.

    The output has one column per group; the number of rows and the data type
    are taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = len(group_lens)
    first_cols = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(arr[:, first_cols], *args)
    out = np.empty((first_out.shape[0], n_groups), dtype=first_out.dtype)
    out[:, 0] = first_out
    for group in prange(1, n_groups):
        col_idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[:, group] = reduce_func_nb(arr[:, col_idxs], *args)
    return out


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        group_map=base_ch.GroupMapSlicer(),
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_to_array_meta_nb(
    group_map: tp.GroupMap,
    reduce_func_nb: tp.ReduceGroupedToArrayMetaFunc,
    *args,
) -> tp.Array2d:
    """Same as `reduce_grouped_meta_nb` but `reduce_func_nb` must return an array.

    The output has one column per group; the number of rows and the data type
    are taken from the result of the first group."""
    group_idxs, group_lens = group_map
    group_starts = np.cumsum(group_lens) - group_lens
    group_ends = group_starts + group_lens
    n_groups = len(group_lens)
    first_cols = group_idxs[group_starts[0] : group_ends[0]]
    first_out = reduce_func_nb(first_cols, 0, *args)
    out = np.empty((first_out.shape[0], n_groups), dtype=first_out.dtype)
    out[:, 0] = first_out
    for group in prange(1, n_groups):
        col_idxs = group_idxs[group_starts[group] : group_ends[group]]
        out[:, group] = reduce_func_nb(col_idxs, group, *args)
    return out
group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
    # The meta callback receives the column indices and the group index instead of data
    group_0_out = reduce_func_nb(group_0_idxs, 0, *args)
    out = np.empty((group_0_out.shape[0], len(group_lens)), dtype=group_0_out.dtype)
    out[:, 0] = group_0_out

    for group in prange(1, len(group_lens)):
        group_len = group_lens[group]
        start_idx = group_start_idxs[group]
        col_idxs = group_idxs[start_idx : start_idx + group_len]
        out[:, group] = reduce_func_nb(col_idxs, group, *args)
    return out


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
        in_c_order=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_flat_grouped_to_array_nb(
    arr: tp.Array2d,
    group_map: tp.GroupMap,
    in_c_order: bool,
    reduce_func_nb: tp.ReduceToArrayFunc,
    *args,
) -> tp.Array2d:
    """Same as `reduce_grouped_to_array_nb` but passes flattened array.

    If `in_c_order` is True, the columns of a group are flattened with `ndarray.flatten()`,
    otherwise with `flatten_forder_nb`."""
    group_idxs, group_lens = group_map
    # Exclusive cumulative sum: offset of each group's first entry within `group_idxs`
    group_start_idxs = np.cumsum(group_lens) - group_lens
    group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
    # Reduce group 0 first to learn the shape and dtype of the per-group result
    if in_c_order:
        group_0_out = reduce_func_nb(arr[:, group_0_idxs].flatten(), *args)
    else:
        group_0_out = reduce_func_nb(flatten_forder_nb(arr[:, group_0_idxs]), *args)
    out = np.empty((group_0_out.shape[0], len(group_lens)), dtype=group_0_out.dtype)
    out[:, 0] = group_0_out

    for group in prange(1, len(group_lens)):
        group_len = group_lens[group]
        start_idx = group_start_idxs[group]
        col_idxs = group_idxs[start_idx : start_idx + group_len]
        if in_c_order:
            out[:, group] = reduce_func_nb(arr[:, col_idxs].flatten(), *args)
        else:
            out[:, group] = reduce_func_nb(flatten_forder_nb(arr[:, col_idxs]), *args)
    return out


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
squeeze_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def squeeze_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, squeeze_func_nb: tp.ReduceFunc, *args) -> tp.Array2d:
    """Squeeze each group of columns into a single column using `squeeze_func_nb`.

    `squeeze_func_nb` must accept the values of one row restricted to the group's columns
    and `*args`. Must return a single value."""
    group_idxs, group_lens = group_map
    # Exclusive cumulative sum: offset of each group's first entry within `group_idxs`
    group_start_idxs = np.cumsum(group_lens) - group_lens
    group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
    # Cell (0, 0) is computed up front so its dtype can be used to allocate `out`
    group_i_0_out = squeeze_func_nb(arr[0][group_0_idxs], *args)
    out = np.empty((arr.shape[0], len(group_lens)), dtype=np.asarray(group_i_0_out).dtype)
    out[0, 0] = group_i_0_out

    for group in prange(len(group_lens)):
        group_len = group_lens[group]
        start_idx = group_start_idxs[group]
        col_idxs = group_idxs[start_idx : start_idx + group_len]
        for i in range(arr.shape[0]):
            # Cell (0, 0) has already been filled above
            if group == 0 and i == 0:
                continue
            out[i, group] = squeeze_func_nb(arr[i][col_idxs], *args)
    return out


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        n_rows=None,
        group_map=base_ch.GroupMapSlicer(),
        squeeze_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def squeeze_grouped_meta_nb(
    n_rows: int,
    group_map: tp.GroupMap,
    squeeze_func_nb: tp.GroupSqueezeMetaFunc,
    *args,
) -> tp.Array2d:
    """Meta version of `squeeze_grouped_nb`.

    `squeeze_func_nb` must accept the row index, the column indices of the group, the group index, and `*args`.
Must return a single value.""" group_idxs, group_lens = group_map group_start_idxs = np.cumsum(group_lens) - group_lens group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]] group_i_0_out = squeeze_func_nb(0, group_0_idxs, 0, *args) out = np.empty((n_rows, len(group_lens)), dtype=np.asarray(group_i_0_out).dtype) out[0, 0] = group_i_0_out for group in prange(len(group_lens)): group_len = group_lens[group] start_idx = group_start_idxs[group] col_idxs = group_idxs[start_idx : start_idx + group_len] for i in range(n_rows): if group == 0 and i == 0: continue out[i, group] = squeeze_func_nb(i, col_idxs, group, *args) return out # ############# Flattening ############# # @register_jitted(cache=True) def flatten_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, in_c_order: bool) -> tp.Array2d: """Flatten each group of columns.""" group_idxs, group_lens = group_map group_start_idxs = np.cumsum(group_lens) - group_lens out = np.full((arr.shape[0] * np.max(group_lens), len(group_lens)), np.nan, dtype=float_) max_len = np.max(group_lens) for group in range(len(group_lens)): group_len = group_lens[group] start_idx = group_start_idxs[group] col_idxs = group_idxs[start_idx : start_idx + group_len] for k in range(group_len): col = col_idxs[k] if in_c_order: out[k::max_len, group] = arr[:, col] else: out[k * arr.shape[0] : (k + 1) * arr.shape[0], group] = arr[:, col] return out @register_jitted(cache=True) def flatten_uniform_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, in_c_order: bool) -> tp.Array2d: """Flatten each group of columns of the same length.""" group_idxs, group_lens = group_map group_start_idxs = np.cumsum(group_lens) - group_lens out = np.empty((arr.shape[0] * np.max(group_lens), len(group_lens)), dtype=arr.dtype) max_len = np.max(group_lens) for group in range(len(group_lens)): group_len = group_lens[group] start_idx = group_start_idxs[group] col_idxs = group_idxs[start_idx : start_idx + group_len] for k in range(group_len): 
col = col_idxs[k]
            if in_c_order:
                out[k::max_len, group] = arr[:, col]
            else:
                out[k * arr.shape[0] : (k + 1) * arr.shape[0], group] = arr[:, col]
    return out


# ############# Proximity ############# #


# NOTE(review): the `reduce_func_nb` annotations of `proximity_reduce_nb` (`tp.ProximityReduceMetaFunc`)
# and `proximity_reduce_meta_nb` (`tp.ReduceFunc`) look swapped - confirm against `_typing`.
@register_jitted(tags={"can_parallel"})
def proximity_reduce_nb(
    arr: tp.Array2d,
    window: int,
    reduce_func_nb: tp.ProximityReduceMetaFunc,
    *args,
) -> tp.Array2d:
    """Flatten `window` surrounding rows and columns and reduce them into a single value using `reduce_func_nb`.

    `reduce_func_nb` must accept the array and `*args`. Must return a single value.

    The neighborhood of each element spans up to `window` rows and columns in every direction
    and is clipped at the array borders."""
    out = np.empty_like(arr, dtype=float_)
    for i in prange(arr.shape[0]):
        for col in range(arr.shape[1]):
            # Half-open bounds of the neighborhood, clipped to the array
            from_i = max(0, i - window)
            to_i = min(i + window + 1, arr.shape[0])
            from_col = max(0, col - window)
            to_col = min(col + window + 1, arr.shape[1])
            stride_arr = arr[from_i:to_i, from_col:to_col]
            out[i, col] = reduce_func_nb(stride_arr.flatten(), *args)
    return out


@register_jitted(tags={"can_parallel"})
def proximity_reduce_meta_nb(
    target_shape: tp.Shape,
    window: int,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array2d:
    """Meta version of `proximity_reduce_nb`.

    `reduce_func_nb` must accept the start row index, the end row index, the start column index, the end column index, and `*args`.
Must return a single value.""" out = np.empty(target_shape, dtype=float_) for i in prange(target_shape[0]): for col in range(target_shape[1]): from_i = max(0, i - window) to_i = min(i + window + 1, target_shape[0]) from_col = max(0, col - window) to_col = min(col + window + 1, target_shape[1]) out[i, col] = reduce_func_nb(from_i, to_i, from_col, to_col, *args) return out # ############# Reducers ############# # @register_jitted(cache=True) def nth_reduce_nb(arr: tp.Array1d, n: int) -> float: """Get n-th element.""" if (n < 0 and abs(n) > arr.shape[0]) or n >= arr.shape[0]: raise ValueError("index is out of bounds") return arr[n] @register_jitted(cache=True) def first_reduce_nb(arr: tp.Array1d) -> float: """Get first non-NA element.""" if arr.shape[0] == 0: raise ValueError("index is out of bounds") for i in range(len(arr)): if not np.isnan(arr[i]): return arr[i] return np.nan @register_jitted(cache=True) def last_reduce_nb(arr: tp.Array1d) -> float: """Get last non-NA element.""" if arr.shape[0] == 0: raise ValueError("index is out of bounds") for i in range(len(arr) - 1, -1, -1): if not np.isnan(arr[i]): return arr[i] return np.nan @register_jitted(cache=True) def first_index_reduce_nb(arr: tp.Array1d) -> int: """Get index of first non-NA element.""" if arr.shape[0] == 0: raise ValueError("index is out of bounds") for i in range(len(arr)): if not np.isnan(arr[i]): return i return -1 @register_jitted(cache=True) def last_index_reduce_nb(arr: tp.Array1d) -> int: """Get index of last non-NA element.""" if arr.shape[0] == 0: raise ValueError("index is out of bounds") for i in range(len(arr) - 1, -1, -1): if not np.isnan(arr[i]): return i return -1 @register_jitted(cache=True) def nth_index_reduce_nb(arr: tp.Array1d, n: int) -> int: """Get index of n-th element including NA elements.""" if (n < 0 and abs(n) > arr.shape[0]) or n >= arr.shape[0]: raise ValueError("index is out of bounds") if n >= 0: return n return arr.shape[0] + n @register_jitted(cache=True) def 
any_reduce_nb(arr: tp.Array1d) -> bool: """Get whether any of the elements are True.""" return np.any(arr) @register_jitted(cache=True) def all_reduce_nb(arr: tp.Array1d) -> bool: """Get whether all of the elements are True.""" return np.all(arr) @register_jitted(cache=True) def min_reduce_nb(arr: tp.Array1d) -> float: """Get min. Ignores NaN.""" return np.nanmin(arr) @register_jitted(cache=True) def max_reduce_nb(arr: tp.Array1d) -> float: """Get max. Ignores NaN.""" return np.nanmax(arr) @register_jitted(cache=True) def mean_reduce_nb(arr: tp.Array1d) -> float: """Get mean. Ignores NaN.""" return np.nanmean(arr) @register_jitted(cache=True) def median_reduce_nb(arr: tp.Array1d) -> float: """Get median. Ignores NaN.""" return np.nanmedian(arr) @register_jitted(cache=True) def std_reduce_nb(arr: tp.Array1d, ddof) -> float: """Get std. Ignores NaN.""" return nanstd_1d_nb(arr, ddof=ddof) @register_jitted(cache=True) def sum_reduce_nb(arr: tp.Array1d) -> float: """Get sum. Ignores NaN.""" return np.nansum(arr) @register_jitted(cache=True) def prod_reduce_nb(arr: tp.Array1d) -> float: """Get product. Ignores NaN.""" return np.nanprod(arr) @register_jitted(cache=True) def nonzero_prod_reduce_nb(arr: tp.Array1d) -> float: """Get product. Ignores zero and NaN. Default value is zero.""" prod = 0.0 for i in range(len(arr)): if not np.isnan(arr[i]) and arr[i] != 0: if prod == 0: prod = 1.0 prod *= arr[i] return prod @register_jitted(cache=True) def count_reduce_nb(arr: tp.Array1d) -> int: """Get count. 
Ignores NaN."""
    return np.sum(~np.isnan(arr))


@register_jitted(cache=True)
def argmin_reduce_nb(arr: tp.Array1d) -> int:
    """Get position of min.

    NaN elements are ignored. Raises `ValueError` if all elements are NaN."""
    # Work on a copy since NaNs are overwritten below
    arr = np.copy(arr)
    mask = np.isnan(arr)
    if np.all(mask):
        raise ValueError("All-NaN slice encountered")
    # +inf can never win the minimum against a valid value
    arr[mask] = np.inf
    return np.argmin(arr)


@register_jitted(cache=True)
def argmax_reduce_nb(arr: tp.Array1d) -> int:
    """Get position of max.

    NaN elements are ignored. Raises `ValueError` if all elements are NaN."""
    # Work on a copy since NaNs are overwritten below
    arr = np.copy(arr)
    mask = np.isnan(arr)
    if np.all(mask):
        raise ValueError("All-NaN slice encountered")
    # -inf can never win the maximum against a valid value
    arr[mask] = -np.inf
    return np.argmax(arr)


@register_jitted(cache=True)
def describe_reduce_nb(arr: tp.Array1d, perc: tp.Array1d, ddof: int) -> tp.Array1d:
    """Get descriptive statistics. Ignores NaN.

    Numba equivalent to `pd.Series(arr).describe(perc)`.

    The output is laid out as count, mean, std, min, one value per entry of `perc`, and max.
    `perc` is multiplied by 100 before being passed to `np.percentile`.
    If there are no valid elements, everything but the count is NaN."""
    arr = arr[~np.isnan(arr)]
    out = np.empty(5 + len(perc), dtype=float_)
    out[0] = len(arr)
    if len(arr) > 0:
        out[1] = np.mean(arr)
        out[2] = nanstd_1d_nb(arr, ddof=ddof)
        out[3] = np.min(arr)
        out[4:-1] = np.percentile(arr, perc * 100)
        out[4 + len(perc)] = np.max(arr)
    else:
        out[1:] = np.nan
    return out


@register_jitted(cache=True)
def cov_reduce_grouped_meta_nb(
    group_idxs: tp.GroupIdxs,
    group: int,
    arr1: tp.Array2d,
    arr2: tp.Array2d,
    ddof: int,
) -> float:
    """Get covariance. Ignores NaN.

    Both arrays are restricted to the group's columns and flattened before calling `nancov_1d_nb`."""
    return nancov_1d_nb(arr1[:, group_idxs].flatten(), arr2[:, group_idxs].flatten(), ddof=ddof)


@register_jitted(cache=True)
def corr_reduce_grouped_meta_nb(group_idxs: tp.GroupIdxs, group: int, arr1: tp.Array2d, arr2: tp.Array2d) -> float:
    """Get correlation coefficient.
Ignores NaN."""
    return nancorr_1d_nb(arr1[:, group_idxs].flatten(), arr2[:, group_idxs].flatten())


@register_jitted(cache=True)
def wmean_range_reduce_meta_nb(from_i: int, to_i: int, col: int, arr1: tp.Array2d, arr2: tp.Array2d) -> float:
    """Get the weighted average.

    Averages `arr1[from_i:to_i, col]` using `arr2[from_i:to_i, col]` as weights.
    Returns NaN if the weights sum to zero."""
    nom_cumsum = 0  # running sum of value * weight
    denum_cumsum = 0  # running sum of weights
    for i in range(from_i, to_i):
        nom_cumsum += arr1[i, col] * arr2[i, col]
        denum_cumsum += arr2[i, col]
    if denum_cumsum == 0:
        return np.nan
    return nom_cumsum / denum_cumsum


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Generic Numba-compiled functions for base operations."""

import numpy as np
from numba import prange
from numba.core.types import Type, Omitted
from numba.extending import overload
from numba.np.numpy_support import as_dtype

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch


def _select_indices_1d_nb(arr, indices, fill_value):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below):
    # when called with Numba types it returns `impl`, otherwise it runs `impl` right away
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(fill_value)
    else:
        a_dtype = arr.dtype
        value_dtype = np.array(fill_value).dtype
    # The output dtype must be able to hold both the array values and the fill value
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, indices, fill_value):
        out =
np.empty(indices.shape, dtype=dtype)
        for i in range(indices.shape[0]):
            # Indices outside of [0, len(arr) - 1] (including negative ones) yield the fill value
            if 0 <= indices[i] <= arr.shape[0] - 1:
                out[i] = arr[indices[i]]
            else:
                out[i] = fill_value
        return out

    if not nb_enabled:
        return impl(arr, indices, fill_value)

    return impl


overload(_select_indices_1d_nb)(_select_indices_1d_nb)


@register_jitted(cache=True)
def select_indices_1d_nb(arr: tp.Array1d, indices: tp.Array1d, fill_value: tp.Scalar) -> tp.Array1d:
    """Select elements of the array by position.

    Positions outside of the array (including negative ones) are replaced by `fill_value`.
    The output dtype is the promotion of the array dtype and the dtype of `fill_value`."""
    return _select_indices_1d_nb(arr, indices, fill_value)


def _select_indices_nb(arr, indices, fill_value):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(fill_value)
    else:
        a_dtype = arr.dtype
        value_dtype = np.array(fill_value).dtype
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, indices, fill_value):
        out = np.empty(indices.shape, dtype=dtype)
        for col in range(indices.shape[1]):
            for i in range(indices.shape[0]):
                # Row positions are looked up within the same column of `arr`
                if 0 <= indices[i, col] <= arr.shape[0] - 1:
                    out[i, col] = arr[indices[i, col], col]
                else:
                    out[i, col] = fill_value
        return out

    if not nb_enabled:
        return impl(arr, indices, fill_value)

    return impl


overload(_select_indices_nb)(_select_indices_nb)


@register_jitted(cache=True)
def select_indices_nb(arr: tp.Array2d, indices: tp.Array2d, fill_value: tp.Scalar) -> tp.Array2d:
    """2-dim version of `select_indices_1d_nb`."""
    return _select_indices_nb(arr, indices, fill_value)


@register_jitted(cache=True)
def shuffle_1d_nb(arr: tp.Array1d, seed: tp.Optional[int] = None) -> tp.Array1d:
    """Shuffle each column in the array.
Specify `seed` to make output deterministic."""
    if seed is not None:
        np.random.seed(seed)
    return np.random.permutation(arr)


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), seed=None),
    merge_func="column_stack",
)
@register_jitted(cache=True)
def shuffle_nb(arr: tp.Array2d, seed: tp.Optional[int] = None) -> tp.Array2d:
    """2-dim version of `shuffle_1d_nb`.

    The seed is set once, then every column is permuted independently."""
    if seed is not None:
        np.random.seed(seed)
    out = np.empty_like(arr, dtype=arr.dtype)

    for col in range(arr.shape[1]):
        out[:, col] = np.random.permutation(arr[:, col])
    return out


def _set_by_mask_1d_nb(arr, mask, value):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(value)
    else:
        a_dtype = arr.dtype
        value_dtype = np.array(value).dtype
    # The output dtype must be able to hold both the array values and the new value
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, mask, value):
        # `astype` returns a new array, so the input is left untouched
        out = arr.astype(dtype)
        out[mask] = value
        return out

    if not nb_enabled:
        return impl(arr, mask, value)

    return impl


overload(_set_by_mask_1d_nb)(_set_by_mask_1d_nb)


@register_jitted(cache=True)
def set_by_mask_1d_nb(arr: tp.Array1d, mask: tp.Array1d, value: tp.Scalar) -> tp.Array1d:
    """Set each element to a value by boolean mask."""
    return _set_by_mask_1d_nb(arr, mask, value)


def _set_by_mask_nb(arr, mask, value):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(value)
    else:
        a_dtype = arr.dtype
        value_dtype = np.array(value).dtype
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, mask, value):
        out = arr.astype(dtype)
        # The mask is applied column by column
        for col in range(arr.shape[1]):
            out[mask[:, col], col] = value
        return out

    if not nb_enabled:
        return impl(arr, mask, value)

    return impl


overload(_set_by_mask_nb)(_set_by_mask_nb)


@register_jitted(cache=True)
def set_by_mask_nb(arr: tp.Array2d, mask: tp.Array2d, value: tp.Scalar) -> tp.Array2d:
    """2-dim version of `set_by_mask_1d_nb`."""
    return _set_by_mask_nb(arr, mask, value)


def _set_by_mask_mult_1d_nb(arr, mask,
values):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(values.dtype)
    else:
        a_dtype = arr.dtype
        value_dtype = values.dtype
    # The output dtype must be able to hold elements of both arrays
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, mask, values):
        out = arr.astype(dtype)
        out[mask] = values[mask]
        return out

    if not nb_enabled:
        return impl(arr, mask, values)

    return impl


overload(_set_by_mask_mult_1d_nb)(_set_by_mask_mult_1d_nb)


@register_jitted(cache=True)
def set_by_mask_mult_1d_nb(arr: tp.Array1d, mask: tp.Array1d, values: tp.Array1d) -> tp.Array1d:
    """Set each element in one array to the corresponding element in another by boolean mask.

    `values` must be of the same shape as in the array."""
    return _set_by_mask_mult_1d_nb(arr, mask, values)


def _set_by_mask_mult_nb(arr, mask, values):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
        value_dtype = as_dtype(values.dtype)
    else:
        a_dtype = arr.dtype
        value_dtype = values.dtype
    dtype = np.promote_types(a_dtype, value_dtype)

    def impl(arr, mask, values):
        out = arr.astype(dtype)
        # The mask is applied column by column
        for col in range(arr.shape[1]):
            out[mask[:, col], col] = values[mask[:, col], col]
        return out

    if not nb_enabled:
        return impl(arr, mask, values)

    return impl


overload(_set_by_mask_mult_nb)(_set_by_mask_mult_nb)


@register_jitted(cache=True)
def set_by_mask_mult_nb(arr: tp.Array2d, mask: tp.Array2d, values: tp.Array2d) -> tp.Array2d:
    """2-dim version of `set_by_mask_mult_1d_nb`."""
    return _set_by_mask_mult_nb(arr, mask, values)


@register_jitted(cache=True)
def first_valid_index_1d_nb(arr: tp.Array1d, check_inf: bool = True) -> int:
    """Get the index of the first valid value.

    A value is valid if it is not NaN and, with `check_inf`, not infinite.
    Returns -1 if there is no valid value."""
    for i in range(arr.shape[0]):
        if not np.isnan(arr[i]) and (not check_inf or not np.isinf(arr[i])):
            return i
    return -1


@register_jitted(cache=True)
def first_valid_index_nb(arr, check_inf: bool = True):
    """2-dim version of `first_valid_index_1d_nb`."""
    out = np.empty(arr.shape[1], dtype=int_)
    for col in range(arr.shape[1]):
        out[col] =
first_valid_index_1d_nb(arr[:, col], check_inf=check_inf)
    return out


@register_jitted(cache=True)
def last_valid_index_1d_nb(arr: tp.Array1d, check_inf: bool = True) -> int:
    """Get the index of the last valid value.

    A value is valid if it is not NaN and, with `check_inf`, not infinite.
    Returns -1 if there is no valid value."""
    for i in range(arr.shape[0] - 1, -1, -1):
        if not np.isnan(arr[i]) and (not check_inf or not np.isinf(arr[i])):
            return i
    return -1


@register_jitted(cache=True)
def last_valid_index_nb(arr, check_inf: bool = True):
    """2-dim version of `last_valid_index_1d_nb`."""
    out = np.empty(arr.shape[1], dtype=int_)
    for col in range(arr.shape[1]):
        out[col] = last_valid_index_1d_nb(arr[:, col], check_inf=check_inf)
    return out


@register_jitted(cache=True)
def fillna_1d_nb(arr: tp.Array1d, value: tp.Scalar) -> tp.Array1d:
    """Replace NaNs with value.

    Numba equivalent to `pd.Series(arr).fillna(value)`."""
    return set_by_mask_1d_nb(arr, np.isnan(arr), value)


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), value=None),
    merge_func="column_stack",
)
@register_jitted(cache=True)
def fillna_nb(arr: tp.Array2d, value: tp.Scalar) -> tp.Array2d:
    """2-dim version of `fillna_1d_nb`."""
    return set_by_mask_nb(arr, np.isnan(arr), value)


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def fbfill_nb(arr: tp.Array2d) -> tp.Array2d:
    """Forward and backward fill NaN values.

    !!!
note If there are no NaN (or any) values, will return `arr`.""" if arr.size == 0: return arr need_fbfill = False for col in range(arr.shape[1]): for i in range(arr.shape[0]): if np.isnan(arr[i, col]): need_fbfill = True break if need_fbfill: break if not need_fbfill: return arr out = np.empty_like(arr) for col in prange(arr.shape[1]): last_valid = np.nan for i in range(arr.shape[0]): if not np.isnan(arr[i, col]): last_valid = arr[i, col] out[i, col] = last_valid if np.isnan(out[0, col]): last_valid = np.nan for i in range(arr.shape[0] - 1, -1, -1): if not np.isnan(arr[i, col]): last_valid = arr[i, col] if np.isnan(out[i, col]): out[i, col] = last_valid return out def _bshift_1d_nb(arr, n, fill_value): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) if isinstance(fill_value, Omitted): fill_value_dtype = np.asarray(fill_value.value).dtype else: fill_value_dtype = as_dtype(fill_value) else: a_dtype = arr.dtype fill_value_dtype = np.array(fill_value).dtype dtype = np.promote_types(a_dtype, fill_value_dtype) def impl(arr, n, fill_value): out = np.empty(arr.shape[0], dtype=dtype) for i in range(out.shape[0]): if i + n <= out.shape[0] - 1: out[i] = arr[i + n] else: out[i] = fill_value return out if not nb_enabled: return impl(arr, n, fill_value) return impl overload(_bshift_1d_nb)(_bshift_1d_nb) @register_jitted(cache=True) def bshift_1d_nb(arr: tp.Array1d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array1d: """Shift backward by `n` positions. Numba equivalent to `pd.Series(arr).shift(-n)`. !!! 
warning This operation looks ahead.""" return _bshift_1d_nb(arr, n, fill_value) def _bshift_nb(arr, n, fill_value): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) if isinstance(fill_value, Omitted): fill_value_dtype = np.asarray(fill_value.value).dtype else: fill_value_dtype = as_dtype(fill_value) else: a_dtype = arr.dtype fill_value_dtype = np.array(fill_value).dtype dtype = np.promote_types(a_dtype, fill_value_dtype) def impl(arr, n, fill_value): out = np.empty_like(arr, dtype=dtype) for col in range(arr.shape[1]): out[:, col] = bshift_1d_nb(arr[:, col], n=n, fill_value=fill_value) return out if not nb_enabled: return impl(arr, n, fill_value) return impl overload(_bshift_nb)(_bshift_nb) @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None, fill_value=None), merge_func="column_stack", ) @register_jitted(cache=True) def bshift_nb(arr: tp.Array2d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array2d: """2-dim version of `bshift_1d_nb`.""" return _bshift_nb(arr, n, fill_value) def _fshift_1d_nb(arr, n, fill_value): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) if isinstance(fill_value, Omitted): fill_value_dtype = np.asarray(fill_value.value).dtype else: fill_value_dtype = as_dtype(fill_value) else: a_dtype = arr.dtype fill_value_dtype = np.array(fill_value).dtype dtype = np.promote_types(a_dtype, fill_value_dtype) def impl(arr, n, fill_value): out = np.empty(arr.shape[0], dtype=dtype) for i in range(out.shape[0]): if i - n >= 0: out[i] = arr[i - n] else: out[i] = fill_value return out if not nb_enabled: return impl(arr, n, fill_value) return impl overload(_fshift_1d_nb)(_fshift_1d_nb) @register_jitted(cache=True) def fshift_1d_nb(arr: tp.Array1d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array1d: """Shift forward by `n` positions. 
Numba equivalent to `pd.Series(arr).shift(n)`.""" return _fshift_1d_nb(arr, n, fill_value) def _fshift_nb(arr, n, fill_value): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) if isinstance(fill_value, Omitted): fill_value_dtype = np.asarray(fill_value.value).dtype else: fill_value_dtype = as_dtype(fill_value) else: a_dtype = arr.dtype fill_value_dtype = np.array(fill_value).dtype dtype = np.promote_types(a_dtype, fill_value_dtype) def impl(arr, n, fill_value): out = np.empty_like(arr, dtype=dtype) for col in range(arr.shape[1]): out[:, col] = fshift_1d_nb(arr[:, col], n=n, fill_value=fill_value) return out if not nb_enabled: return impl(arr, n, fill_value) return impl overload(_fshift_nb)(_fshift_nb) @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None, fill_value=None), merge_func="column_stack", ) @register_jitted(cache=True) def fshift_nb(arr: tp.Array2d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array2d: """2-dim version of `fshift_1d_nb`.""" return _fshift_nb(arr, n, fill_value) @register_jitted(cache=True) def diff_1d_nb(arr: tp.Array1d, n: int = 1) -> tp.Array1d: """Compute the 1-th discrete difference. Numba equivalent to `pd.Series(arr).diff()`.""" out = np.empty(arr.shape[0], dtype=float_) for i in range(out.shape[0]): if i - n >= 0: out[i] = arr[i] - arr[i - n] else: out[i] = np.nan return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def diff_nb(arr: tp.Array2d, n: int = 1) -> tp.Array2d: """2-dim version of `diff_1d_nb`.""" out = np.empty_like(arr, dtype=float_) for col in prange(arr.shape[1]): out[:, col] = diff_1d_nb(arr[:, col], n=n) return out @register_jitted(cache=True) def pct_change_1d_nb(arr: tp.Array1d, n: int = 1) -> tp.Array1d: """Compute the percentage change. 
Numba equivalent to `pd.Series(arr).pct_change()`.""" out = np.empty(arr.shape[0], dtype=float_) for i in range(out.shape[0]): if i - n >= 0: out[i] = (arr[i] - arr[i - n]) / arr[i - n] else: out[i] = np.nan return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def pct_change_nb(arr: tp.Array2d, n: int = 1) -> tp.Array2d: """2-dim version of `pct_change_1d_nb`.""" out = np.empty_like(arr, dtype=float_) for col in prange(arr.shape[1]): out[:, col] = pct_change_1d_nb(arr[:, col], n=n) return out @register_jitted(cache=True) def bfill_1d_nb(arr: tp.Array1d) -> tp.Array1d: """Fill NaNs by propagating first valid observation backward. Numba equivalent to `pd.Series(arr).fillna(method='bfill')`. !!! warning This operation looks ahead.""" out = np.empty_like(arr, dtype=arr.dtype) lastval = arr[-1] for i in range(arr.shape[0] - 1, -1, -1): if np.isnan(arr[i]): out[i] = lastval else: lastval = out[i] = arr[i] return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def bfill_nb(arr: tp.Array2d) -> tp.Array2d: """2-dim version of `bfill_1d_nb`.""" out = np.empty_like(arr, dtype=arr.dtype) for col in prange(arr.shape[1]): out[:, col] = bfill_1d_nb(arr[:, col]) return out @register_jitted(cache=True) def ffill_1d_nb(arr: tp.Array1d) -> tp.Array1d: """Fill NaNs by propagating last valid observation forward. 
Numba equivalent to `pd.Series(arr).fillna(method='ffill')`.""" out = np.empty_like(arr, dtype=arr.dtype) lastval = arr[0] for i in range(arr.shape[0]): if np.isnan(arr[i]): out[i] = lastval else: lastval = out[i] = arr[i] return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def ffill_nb(arr: tp.Array2d) -> tp.Array2d: """2-dim version of `ffill_1d_nb`.""" out = np.empty_like(arr, dtype=arr.dtype) for col in prange(arr.shape[1]): out[:, col] = ffill_1d_nb(arr[:, col]) return out def _nanprod_nb(arr): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) else: a_dtype = arr.dtype dtype = np.promote_types(a_dtype, int) def impl(arr): out = np.empty(arr.shape[1], dtype=dtype) for col in prange(arr.shape[1]): out[col] = np.nanprod(arr[:, col]) return out if not nb_enabled: return impl(arr) return impl overload(_nanprod_nb)(_nanprod_nb) @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)), merge_func="concat", ) @register_jitted(cache=True) def nanprod_nb(arr: tp.Array2d) -> tp.Array1d: """Numba equivalent of `np.nanprod` along axis 0.""" return _nanprod_nb(arr) def _nancumsum_nb(arr): nb_enabled = isinstance(arr, Type) if nb_enabled: a_dtype = as_dtype(arr.dtype) else: a_dtype = arr.dtype dtype = np.promote_types(a_dtype, int) def impl(arr): out = np.empty(arr.shape, dtype=dtype) for col in prange(arr.shape[1]): out[:, col] = np.nancumsum(arr[:, col]) return out if not nb_enabled: return impl(arr) return impl overload(_nancumsum_nb)(_nancumsum_nb) @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)), merge_func="column_stack", ) @register_jitted(cache=True) def nancumsum_nb(arr: tp.Array2d) -> tp.Array2d: """Numba equivalent of `np.nancumsum` along axis 0.""" return 
_nancumsum_nb(arr)


def _nancumprod_nb(arr):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
    else:
        a_dtype = arr.dtype
    # Promote with `int` so that narrower dtypes (such as bool) get an integer-capable output
    dtype = np.promote_types(a_dtype, int)

    def impl(arr):
        out = np.empty(arr.shape, dtype=dtype)
        for col in prange(arr.shape[1]):
            out[:, col] = np.nancumprod(arr[:, col])
        return out

    if not nb_enabled:
        return impl(arr)

    return impl


overload(_nancumprod_nb)(_nancumprod_nb)


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
    merge_func="column_stack",
)
@register_jitted(cache=True)
def nancumprod_nb(arr: tp.Array2d) -> tp.Array2d:
    """Numba equivalent of `np.nancumprod` along axis 0."""
    return _nancumprod_nb(arr)


def _nansum_nb(arr):
    # Doubles as a plain-Python implementation and as its own Numba overload (registered below)
    nb_enabled = isinstance(arr, Type)
    if nb_enabled:
        a_dtype = as_dtype(arr.dtype)
    else:
        a_dtype = arr.dtype
    # Promote with `int` so that narrower dtypes (such as bool) get an integer-capable output
    dtype = np.promote_types(a_dtype, int)

    def impl(arr):
        out = np.empty(arr.shape[1], dtype=dtype)
        for col in prange(arr.shape[1]):
            out[col] = np.nansum(arr[:, col])
        return out

    if not nb_enabled:
        return impl(arr)

    return impl


overload(_nansum_nb)(_nansum_nb)


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
    merge_func="concat",
)
@register_jitted(cache=True)
def nansum_nb(arr: tp.Array2d) -> tp.Array1d:
    """Numba equivalent of `np.nansum` along axis 0."""
    return _nansum_nb(arr)


@register_jitted(cache=True)
def nancnt_1d_nb(arr: tp.Array1d) -> int:
    """Compute count while ignoring NaNs and not allocating any arrays."""
    cnt = 0
    for i in range(arr.shape[0]):
        if not np.isnan(arr[i]):
            cnt += 1
    return cnt


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nancnt_nb(arr: tp.Array2d) -> tp.Array1d:
    """2-dim version of `nancnt_1d_nb`."""
    out = np.empty(arr.shape[1], dtype=int_)
    for col in prange(arr.shape[1]):
        out[col] = nancnt_1d_nb(arr[:,
@register_jitted(cache=True)
def nanpercentile_noarr_1d_nb(arr: tp.Array1d, q: float) -> float:
    """Numba equivalent of `np.nanpercentile` that does not allocate any arrays.

    The percentile is found by repeatedly scanning the array for the next distinct
    value (ascending for `q < 50`, descending otherwise) until enough elements
    have been passed, and then linearly interpolating between the two neighboring
    order statistics.

    Args:
        arr: 1-dim array; NaN elements are ignored.
        q: Percentile; values outside of [0, 100] are clamped to that range.

    Returns:
        The percentile, or NaN if `arr` contains no non-NaN element.

    !!! note
        Has worst case time complexity of O(N^2), which makes it much slower than `np.nanpercentile`,
        but still faster if used in rolling calculations, especially for `q` near 0 and 100."""
    # Clamp the percentile into [0, 100]
    if q < 0:
        q = 0
    elif q > 100:
        q = 100
    # For q >= 50, search from the maximum downwards using the mirrored percentile
    do_min = q < 50
    if not do_min:
        q = 100 - q
    # Count non-NaN elements
    cnt = arr.shape[0]
    for i in range(arr.shape[0]):
        if np.isnan(arr[i]):
            cnt -= 1
    if cnt == 0:
        return np.nan
    # Fractional zero-based rank (counted from the search start) and the two integer ranks around it
    nth_float = q / 100 * (cnt - 1)
    if nth_float % 1 == 0:
        nth1 = nth2 = int(nth_float)
    else:
        nth1 = int(nth_float)
        nth2 = nth1 + 1
    # NaN marks "not found yet"
    found1 = np.nan
    found2 = np.nan
    # Number of elements passed so far (duplicates included)
    k = 0
    if do_min:
        prev_val = -np.inf
    else:
        prev_val = np.inf
    while True:
        # Each pass finds the next distinct value beyond `prev_val` and how often it occurs
        n_same = 0
        if do_min:
            curr_val = np.inf
            for i in range(arr.shape[0]):
                if not np.isnan(arr[i]):
                    if arr[i] > prev_val:
                        if arr[i] < curr_val:
                            # New smallest candidate: restart the duplicate counter
                            curr_val = arr[i]
                            n_same = 0
                        if arr[i] == curr_val:
                            n_same += 1
        else:
            curr_val = -np.inf
            for i in range(arr.shape[0]):
                if not np.isnan(arr[i]):
                    if arr[i] < prev_val:
                        if arr[i] > curr_val:
                            # New largest candidate: restart the duplicate counter
                            curr_val = arr[i]
                            n_same = 0
                        if arr[i] == curr_val:
                            n_same += 1
        prev_val = curr_val
        k += n_same
        # A rank is reached once at least `rank + 1` elements have been passed
        if np.isnan(found1) and k >= nth1 + 1:
            found1 = curr_val
        if np.isnan(found2) and k >= nth2 + 1:
            found2 = curr_val
            break
    if found1 == found2:
        return found1
    # Linear interpolation between the two neighboring order statistics
    factor = (nth_float - nth1) / (nth2 - nth1)
    return factor * (found2 - found1) + found1
@register_jitted(cache=True)
def nanvar_1d_nb(arr: tp.Array1d, ddof: int = 0) -> float:
    """Compute the variance of a 1-dim array while ignoring NaN values.

    Mirrors `np.nanvar`: the sum of squared deviations from the NaN-aware mean
    is divided by the number of non-NaN elements minus `ddof`.

    Args:
        arr: 1-dim array; NaN elements are ignored.
        ddof: Delta degrees of freedom subtracted from the number of valid elements.

    Returns:
        The variance, or NaN if the number of valid elements minus `ddof` is not positive."""
    n_valid = 0
    for i in range(arr.shape[0]):
        if not np.isnan(arr[i]):
            n_valid += 1
    divisor = n_valid - ddof
    if divisor <= 0:
        return np.nan
    mean = np.nanmean(arr)
    sq_dev_sum = 0.0
    for i in range(arr.shape[0]):
        value = arr[i]
        if np.isnan(value):
            continue
        sq_dev_sum += abs(value - mean) ** 2
    return sq_dev_sum / divisor
@register_jitted(cache=True)
def nancorr_1d_nb(arr1: tp.Array1d, arr2: tp.Array1d) -> float:
    """Compute the Pearson correlation coefficient of two arrays while ignoring NaN values.

    Only positions where both arrays are non-NaN contribute. The coefficient is built
    from running sums as `(E[xy] - E[x]E[y]) / sqrt((E[x^2] - E[x]^2) * (E[y^2] - E[y]^2))`.

    Args:
        arr1: First 1-dim array.
        arr2: Second 1-dim array of the same length.

    Returns:
        The correlation coefficient, or NaN if there is no valid pair or the denominator is zero.

    Raises:
        ValueError: If the arrays differ in length."""
    if len(arr1) != len(arr2):
        raise ValueError("Arrays must have the same length")
    sum_x = 0.0
    sum_y = 0.0
    sumsq_x = 0.0
    sumsq_y = 0.0
    sum_xy = 0.0
    n_pairs = 0
    for i in range(arr1.shape[0]):
        x = arr1[i]
        y = arr2[i]
        if np.isnan(x) or np.isnan(y):
            continue
        sum_x += x
        sum_y += y
        sumsq_x += float(x) ** 2
        sumsq_y += float(y) ** 2
        sum_xy += x * y
        n_pairs += 1
    if n_pairs == 0:
        return np.nan
    mean_x = sum_x / n_pairs
    mean_y = sum_y / n_pairs
    meansq_x = sumsq_x / n_pairs
    meansq_y = sumsq_y / n_pairs
    mean_xy = sum_xy / n_pairs
    num = mean_xy - mean_x * mean_y
    denom = np.sqrt((meansq_x - mean_x**2) * (meansq_y - mean_y**2))
    if denom == 0:
        return np.nan
    return num / denom
@register_jitted(cache=True)
def polyfit_1d_nb(x: tp.Array1d, y: tp.Array1d, deg: int, stabilize: bool = False) -> tp.Array1d:
    """Compute the least squares polynomial fit.

    Builds the matrix with columns `x**0, x**1, ..., x**deg` and solves the least
    squares problem with `np.linalg.lstsq`.

    Args:
        x: 1-dim array of x-coordinates.
        y: 1-dim array of y-coordinates.
        deg: Degree of the fitting polynomial. Must be non-negative; `deg=0` fits a constant.
        stabilize: Whether to build the powers iteratively and scale each column of the
            matrix by its Euclidean norm before solving. The solution is scaled back afterwards.

    Returns:
        Array with `deg + 1` coefficients, highest power first.

    Raises:
        ValueError: If `deg` is negative."""
    if deg < 0:
        raise ValueError("Degree must be non-negative")
    # Column 0 is the constant term; the remaining columns are filled below
    mat_ = np.ones(shape=(x.shape[0], deg + 1))
    # Column 1 exists only for deg >= 1; writing it unconditionally goes out of bounds for deg == 0
    if deg >= 1:
        mat_[:, 1] = x
    if stabilize:
        for n in range(2, deg + 1):
            mat_[:, n] = mat_[:, n - 1] * x
        scale_vect = np.empty((deg + 1,), dtype=float_)
        for n in range(0, deg + 1):
            col_norm = np.linalg.norm(mat_[:, n])
            scale_vect[n] = col_norm
            mat_[:, n] /= col_norm
        # Undo the column scaling on the solution
        det_ = np.linalg.lstsq(mat_, y, rcond=-1)[0] / scale_vect
    else:
        for n in range(2, deg + 1):
            mat_[:, n] = x**n
        det_ = np.linalg.lstsq(mat_, y, rcond=-1)[0]
    # Columns are ordered by ascending power, so reverse to return the highest power first
    return det_[::-1]
@register_jitted(cache=True)
def crossed_above_1d_nb(
    arr1: tp.Array1d,
    arr2: tp.FlexArray1dLike,
    wait: int = 0,
    dropna: bool = False,
) -> tp.Array1d:
    """Mark the rows where the first array crosses above the second array.

    A crossover requires `arr1` to have been below `arr2` first. It is then signaled on the
    `(wait + 1)`-th row on which `arr1` is above `arr2` without dropping below in between.

    If `dropna` is True, produces the same results as if all rows with at least one NaN were dropped.
    Otherwise, a NaN in either array resets the state."""
    arr2_ = to_1d_array_nb(np.asarray(arr2))
    out = np.empty(arr1.shape, dtype=np.bool_)
    seen_below = False
    n_confirmed = 0
    for i in range(arr1.shape[0]):
        left = arr1[i]
        right = flex_select_1d_nb(arr2_, i)
        # Default to no signal; only a confirmed crossover overrides it
        out[i] = False
        if np.isnan(left) or np.isnan(right):
            if not dropna:
                seen_below = False
                n_confirmed = 0
            continue
        if left > right:
            if seen_below:
                n_confirmed += 1
                out[i] = n_confirmed == wait + 1
        elif left < right:
            n_confirmed = 0
            seen_below = True
        else:
            # Equal values interrupt an ongoing confirmation
            if n_confirmed > 0:
                seen_below = False
            n_confirmed = 0
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="group_map"),
    arg_take_spec=dict(
        arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
        group_map=base_ch.GroupMapSlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def demean_nb(arr: tp.Array2d, group_map: tp.GroupMap) -> tp.Array2d:
    """Subtract from each value the mean of its group in the same row.

    For every row and group, the mean is taken over the non-NaN values of the group's
    columns. NaN values stay NaN."""
    group_idxs, group_lens = group_map
    group_start_idxs = np.cumsum(group_lens) - group_lens
    out = np.empty_like(arr, dtype=float_)

    for group in prange(len(group_lens)):
        n_cols = group_lens[group]
        first = group_start_idxs[group]
        cols = group_idxs[first : first + n_cols]

        for row in range(arr.shape[0]):
            # First pass: sum and count the valid values of the group in this row
            row_sum = 0
            row_cnt = 0
            for k in range(n_cols):
                value = arr[row, cols[k]]
                if not np.isnan(value):
                    row_sum += value
                    row_cnt += 1
            if row_cnt > 0:
                row_mean = row_sum / row_cnt
            else:
                row_mean = np.nan
            # Second pass: write the demeaned values
            for k in range(n_cols):
                col = cols[k]
                value = arr[row, col]
                if np.isnan(value):
                    out[row, col] = np.nan
                else:
                    out[row, col] = value - row_mean
    return out
@register_jitted(cache=True)
def to_renko_ohlc_1d_nb(
    arr: tp.Array1d,
    brick_size: tp.FlexArray1dLike,
    relative: tp.FlexArray1dLike = False,
    start_value: tp.Optional[float] = None,
    max_out_len: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array2d, tp.Array1d]:
    """Convert to Renko OHLC format.

    Args:
        arr: 1-dim array of values; NaN values are skipped.
        brick_size: Brick size, read per row of `arr`. The absolute value is used.
        relative: Whether the brick size is a fraction of the previous brick value
            rather than an absolute amount, read per row of `arr`.
        start_value: Value to anchor the first brick to. If None, it is derived from
            the first non-NaN value in `arr`.
        max_out_len: Number of bricks to allocate space for. If None, uses the length of `arr`.

    Returns:
        Tuple of a 2-dim array with the columns open, high, low, and close (one row per brick),
        and a 1-dim array with the row in `arr` at which each brick was produced.

    Raises:
        IndexError: If more bricks are produced than there is allocated space."""
    brick_size_ = to_1d_array_nb(np.asarray(brick_size))
    relative_ = to_1d_array_nb(np.asarray(relative))
    if max_out_len is None:
        out_n = arr.shape[0]
    else:
        out_n = max_out_len
    arr_out = np.empty((out_n, 4), dtype=float_)
    idx_out = np.empty(out_n, dtype=int_)
    # Value of the latest brick boundary; NaN until the first non-NaN value is seen
    prev_value = np.nan
    # Number of bricks written so far
    k = 0
    # Direction of the latest move: 1 (up), -1 (down), or 0 (none yet)
    trend = 0
    for i in range(arr.shape[0]):
        _brick_size = abs(flex_select_1d_nb(brick_size_, i))
        _relative = flex_select_1d_nb(relative_, i)
        curr_value = arr[i]
        if np.isnan(curr_value):
            continue
        if np.isnan(prev_value):
            # The first valid value only sets the anchor and produces no brick
            if start_value is None:
                if not _relative:
                    # Align the anchor to a multiple of the brick size
                    prev_value = curr_value - curr_value % _brick_size
                else:
                    prev_value = curr_value
            else:
                prev_value = start_value
            continue
        if _relative:
            diff = (curr_value - prev_value) / prev_value
        else:
            diff = curr_value - prev_value
        # A single value can span several bricks
        while abs(diff) >= _brick_size:
            open_value = prev_value
            prev_trend = trend
            if diff >= 0:
                if _relative:
                    prev_value *= 1 + _brick_size
                else:
                    prev_value += _brick_size
                trend = 1
            else:
                if _relative:
                    prev_value *= 1 - _brick_size
                else:
                    prev_value -= _brick_size
                trend = -1
            if _relative:
                diff = (curr_value - prev_value) / prev_value
            else:
                diff = curr_value - prev_value
            # The first brick-sized move against the previous trend shifts the boundary
            # but does not emit a brick
            if trend == -prev_trend:
                continue
            if k >= len(arr_out):
                raise IndexError("Index out of range. Set a higher max_out_len.")
            if trend == 1:
                high_value = prev_value
                low_value = open_value
            else:
                high_value = open_value
                low_value = prev_value
            close_value = prev_value
            arr_out[k, 0] = open_value
            arr_out[k, 1] = high_value
            arr_out[k, 2] = low_value
            arr_out[k, 3] = close_value
            idx_out[k] = i
            k += 1
    # Return only the filled part of the pre-allocated arrays
    return arr_out[:k], idx_out[:k]
@register_jitted(cache=True)
def realign_1d_nb(
    arr: tp.Array1d,
    source_index: tp.Array1d,
    target_index: tp.Array1d,
    source_freq: tp.Optional[tp.Scalar] = None,
    target_freq: tp.Optional[tp.Scalar] = None,
    source_rbound: bool = False,
    target_rbound: bool = False,
    nan_value: tp.Scalar = np.nan,
    ffill: bool = True,
) -> tp.Array1d:
    """Get the latest in `arr` at each index in `target_index` based on `source_index`.

    If `source_rbound` is True, then each element in `source_index` is effectively located at
    the right bound, which is the frequency or the next element (excluding) if the frequency is None.
    The same for `target_rbound` and `target_index`.

    Args:
        arr: 1-dim array of source values.
        source_index: Index of `arr`.
        target_index: Index to realign `arr` to.
        source_freq: Frequency of `source_index`, used to locate the right bound.
        target_freq: Frequency of `target_index`, used to locate the right bound.
        source_rbound: Whether to locate source elements at their right bound.
        target_rbound: Whether to locate target elements at their right bound.
        nan_value: Value written where no source value is available.
        ffill: Whether to forward-fill the latest valid value.

    Returns:
        1-dim array with one value per element in `target_index`.

    !!! note
        Both index arrays must be increasing. Repeating values are allowed.

        If `arr` contains bar data, both indexes must represent the opening time."""
    # `target_rbound` defaults to False (not None) to honor its `bool` annotation and
    # to match `realign_nb`; None was only ever evaluated for truthiness
    return _realign_1d_nb(
        arr,
        source_index,
        target_index,
        source_freq,
        target_freq,
        source_rbound,
        target_rbound,
        nan_value,
        ffill,
    )
@register_jitted(cache=True)
def iter_above_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
    """Check whether `arr1` is above `arr2` at specific row and column.

    Returns False for a negative row or if either value is NaN."""
    if i >= 0:
        left = flex_select_nb(arr1, i, col)
        right = flex_select_nb(arr2, i, col)
        if not (np.isnan(left) or np.isnan(right)):
            return left > right
    return False
@register_jitted(cache=True)
def iter_crossed_below_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
    """Check whether `arr1` crossed below `arr2` at specific row and column.

    A cross requires `arr1` to be strictly above `arr2` at the previous row and strictly
    below at the current row. Returns False if there is no previous row or if any of
    the four values is NaN."""
    # Both the current and the previous row must exist
    if i < 1:
        return False
    prev_i = i - 1
    left_prev = flex_select_nb(arr1, prev_i, col)
    right_prev = flex_select_nb(arr2, prev_i, col)
    left_now = flex_select_nb(arr1, i, col)
    right_now = flex_select_nb(arr2, i, col)
    any_nan = np.isnan(left_prev) or np.isnan(right_prev) or np.isnan(left_now) or np.isnan(right_now)
    if any_nan:
        return False
    was_above = left_prev > right_prev
    is_below = left_now < right_now
    return was_above and is_below
@register_jitted(cache=True)
def linear_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
    """Get the value at a specific position in a target size using linear interpolation.

    Position `i` of the target is mapped proportionally onto the source and the value is
    interpolated between the two source elements surrounding the mapped position."""
    last_target_i = target_size - 1
    last_source_i = source_size - 1
    # Positions that map exactly onto a source element need no interpolation
    if i == 0 or source_size == 1 or target_size == 1:
        return float(flex_select_1d_nb(arr, 0))
    if source_size == target_size:
        return float(flex_select_1d_nb(arr, i))
    if i == last_target_i:
        return float(flex_select_1d_nb(arr, last_source_i))
    source_pos = i / last_target_i * last_source_i
    lower_i = int(np.floor(source_pos))
    upper_i = int(np.ceil(source_pos))
    weight = source_pos - lower_i
    lower_value = float(flex_select_1d_nb(arr, lower_i))
    upper_value = float(flex_select_1d_nb(arr, upper_i))
    return lower_value + weight * (upper_value - lower_value)
@register_jitted(cache=True)
def mixed_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
    """Get the value at a specific position in a target size using mixed interpolation.

    The discrete interpolation is tried first. Only if it yields NaN, the linear
    interpolation is used instead. This way, the vertical scale of the pattern array
    is respected."""
    discrete_value = discrete_interp_nb(arr, i, source_size, target_size)
    if not np.isnan(discrete_value):
        return discrete_value
    return linear_interp_nb(arr, i, source_size, target_size)
@register_jitted(cache=True)
def interp_resize_1d_nb(arr: tp.FlexArray1d, target_size: int, interp_mode: int) -> tp.Array1d:
    """Resize an array to `target_size` elements by calling `interp_nb` at each target position."""
    source_size = arr.size
    resized = np.empty(target_size, dtype=float_)
    for target_i in range(target_size):
        resized[target_i] = interp_nb(arr, target_i, source_size, target_size, interp_mode)
    return resized
Returns the resized and rescaled pattern and max error.""" max_error_ = to_1d_array_nb(np.asarray(max_error)) if max_error_interp_mode is None: max_error_interp_mode = interp_mode fit_pattern = interp_resize_1d_nb( pattern, len(arr), interp_mode, ) fit_max_error = interp_resize_1d_nb( max_error_, len(arr), max_error_interp_mode, ) if np.isnan(vmin): vmin = np.nanmin(arr) else: vmin = vmin if np.isnan(vmax): vmax = np.nanmax(arr) else: vmax = vmax if np.isnan(pmin): pmin = np.nanmin(fit_pattern) else: pmin = pmin if np.isnan(pmax): pmax = np.nanmax(fit_pattern) else: pmax = pmax if invert: fit_pattern = pmax + pmin - fit_pattern if rescale_mode == RescaleMode.Rebase: if not np.isnan(pmin): if fit_pattern[0] == 0: pmin = np.nan else: pmin = pmin / fit_pattern[0] * arr[0] if not np.isnan(pmax): if fit_pattern[0] == 0: pmax = np.nan else: pmax = pmax / fit_pattern[0] * arr[0] if rescale_mode == RescaleMode.Rebase: if fit_pattern[0] == 0: fit_pattern = np.full(fit_pattern.shape, np.nan) else: fit_pattern = fit_pattern / fit_pattern[0] * arr[0] fit_max_error = fit_max_error * fit_pattern fit_pattern = np.clip(fit_pattern, pmin, pmax) if rescale_mode == RescaleMode.MinMax: fit_pattern = rescale_nb(fit_pattern, (pmin, pmax), (vmin, vmax)) if error_type == ErrorType.Absolute: if pmax - pmin == 0: fit_max_error = np.full(fit_max_error.shape, np.nan) else: fit_max_error = fit_max_error / (pmax - pmin) * (vmax - vmin) else: fit_max_error = fit_max_error * fit_pattern return fit_pattern, fit_max_error @register_jitted(cache=True) def pattern_similarity_nb( arr: tp.Array1d, pattern: tp.Array1d, interp_mode: int = InterpMode.Mixed, rescale_mode: int = RescaleMode.MinMax, vmin: float = np.nan, vmax: float = np.nan, pmin: float = np.nan, pmax: float = np.nan, invert: bool = False, error_type: int = ErrorType.Absolute, distance_measure: int = DistanceMeasure.MAE, max_error: tp.FlexArray1dLike = np.nan, max_error_interp_mode: tp.Optional[int] = None, max_error_as_maxdist: bool = 
False, max_error_strict: bool = False, min_pct_change: float = np.nan, max_pct_change: float = np.nan, min_similarity: float = np.nan, minp: tp.Optional[int] = None, ) -> float: """Get the similarity between an array and a pattern array. At each position in the array, the value in `arr` is first mapped into the range of `pattern`. Then, the absolute distance between the actual and expected value is calculated (= error). This error is then divided by the maximum error at this position to get a relative value between 0 and 1. Finally, all relative errors are added together and subtracted from 1 to get the similarity measure. * For `interp_mode`, see `vectorbtpro.generic.enums.InterpMode` * For `rescale_mode`, see `vectorbtpro.generic.enums.RescaleMode` * For `error_type`, see `vectorbtpro.generic.enums.ErrorType` * For `distance_measure`, see `vectorbtpro.generic.enums.DistanceMeasure` """ max_error_ = to_1d_array_nb(np.asarray(max_error)) if len(arr) == 0: return np.nan if len(pattern) == 0: return np.nan if rescale_mode == RescaleMode.Rebase: if np.isnan(pattern[0]): return np.nan if np.isnan(arr[0]): return np.nan if max_error_interp_mode is None or max_error_interp_mode == -1: _max_error_interp_mode = interp_mode else: _max_error_interp_mode = max_error_interp_mode max_size = max(arr.shape[0], pattern.shape[0]) if error_type != ErrorType.Absolute and error_type != ErrorType.Relative: raise ValueError("Invalid error type") if ( distance_measure != DistanceMeasure.MAE and distance_measure != DistanceMeasure.MSE and distance_measure != DistanceMeasure.RMSE ): raise ValueError("Invalid distance mode") if minp is None: minp = arr.shape[0] min_max_required = False if rescale_mode == RescaleMode.MinMax: min_max_required = True if not np.isnan(min_pct_change): min_max_required = True if not np.isnan(max_pct_change): min_max_required = True if not max_error_as_maxdist: min_max_required = True if invert: min_max_required = True if min_max_required: vmin_set = not 
np.isnan(vmin) vmax_set = not np.isnan(vmax) pmin_set = not np.isnan(pmin) pmax_set = not np.isnan(pmax) if not vmin_set or not vmax_set or not pmin_set or not pmax_set: for i in range(max_size): if arr.shape[0] >= pattern.shape[0]: arr_elem = arr[i] else: arr_elem = interp_nb(arr, i, arr.shape[0], pattern.shape[0], interp_mode) if pattern.shape[0] >= arr.shape[0]: pattern_elem = pattern[i] else: pattern_elem = interp_nb(pattern, i, pattern.shape[0], arr.shape[0], interp_mode) if not np.isnan(arr_elem): if not vmin_set and (np.isnan(vmin) or arr_elem < vmin): vmin = arr_elem if not vmax_set and (np.isnan(vmax) or arr_elem > vmax): vmax = arr_elem if not np.isnan(pattern_elem): if not pmin_set and (np.isnan(pmin) or pattern_elem < pmin): pmin = pattern_elem if not pmax_set and (np.isnan(pmax) or pattern_elem > pmax): pmax = pattern_elem if np.isnan(vmin) or np.isnan(vmax): return np.nan if np.isnan(pmin) or np.isnan(pmax): return np.nan if vmin == vmax and rescale_mode == RescaleMode.MinMax: return np.nan if pmin == pmax and rescale_mode == RescaleMode.MinMax: return np.nan if not np.isnan(min_pct_change) and vmin != 0 and (vmax - vmin) / vmin < min_pct_change: return np.nan if not np.isnan(max_pct_change) and vmin != 0 and (vmax - vmin) / vmin > max_pct_change: return np.nan first_pattern_elem = pattern[0] if invert: first_pattern_elem = pmax + pmin - first_pattern_elem if rescale_mode == RescaleMode.Rebase: if not np.isnan(pmin): if first_pattern_elem == 0: pmin = np.nan else: pmin = pmin / first_pattern_elem * arr[0] if not np.isnan(pmax): if first_pattern_elem == 0: pmax = np.nan else: pmax = pmax / first_pattern_elem * arr[0] if rescale_mode == RescaleMode.Rebase or rescale_mode == RescaleMode.Disable: if not np.isnan(pmin) and not np.isnan(vmin): _min = min(pmin, vmin) else: _min = vmin if not np.isnan(pmax) and not np.isnan(vmax): _max = max(pmax, vmax) else: _max = vmax else: _min = vmin _max = vmax distance_sum = 0.0 maxdistance_sum = 0.0 nan_count = 0 for 
i in range(max_size): if i < arr.shape[0]: if np.isnan(arr[i]): nan_count += 1 if max_size - nan_count < minp: return np.nan if arr.shape[0] == pattern.shape[0]: arr_elem = arr[i] pattern_elem = pattern[i] _max_error = flex_select_1d_nb(max_error_, i) elif arr.shape[0] > pattern.shape[0]: arr_elem = arr[i] pattern_elem = interp_nb(pattern, i, pattern.shape[0], arr.shape[0], interp_mode) _max_error = interp_nb(max_error_, i, pattern.shape[0], arr.shape[0], _max_error_interp_mode) else: arr_elem = interp_nb(arr, i, arr.shape[0], pattern.shape[0], interp_mode) pattern_elem = pattern[i] _max_error = flex_select_1d_nb(max_error_, i) if not np.isnan(arr_elem) and not np.isnan(pattern_elem): if invert: pattern_elem = pmax + pmin - pattern_elem if rescale_mode == RescaleMode.Rebase: if first_pattern_elem == 0: pattern_elem = np.nan else: pattern_elem = pattern_elem / first_pattern_elem * arr[0] if error_type == ErrorType.Absolute: _max_error = _max_error * pattern_elem if not np.isnan(vmin) and arr_elem < vmin: arr_elem = vmin if not np.isnan(vmax) and arr_elem > vmax: arr_elem = vmax if not np.isnan(pmin) and pattern_elem < pmin: pattern_elem = pmin if not np.isnan(pmax) and pattern_elem > pmax: pattern_elem = pmax if rescale_mode == RescaleMode.MinMax: if pmax - pmin == 0: pattern_elem = np.nan else: pattern_elem = (pattern_elem - pmin) / (pmax - pmin) * (vmax - vmin) + vmin if error_type == ErrorType.Absolute: if pmax - pmin == 0: _max_error = np.nan else: _max_error = _max_error / (pmax - pmin) * (vmax - vmin) if distance_measure == DistanceMeasure.MAE: if error_type == ErrorType.Absolute: dist = abs(arr_elem - pattern_elem) else: if pattern_elem == 0: dist = np.nan else: dist = abs(arr_elem - pattern_elem) / pattern_elem else: if error_type == ErrorType.Absolute: dist = (arr_elem - pattern_elem) ** 2 else: if pattern_elem == 0: dist = np.nan else: dist = ((arr_elem - pattern_elem) / pattern_elem) ** 2 if max_error_as_maxdist: if np.isnan(_max_error): continue maxdist 
= _max_error else: if distance_measure == DistanceMeasure.MAE: if error_type == ErrorType.Absolute: maxdist = max(pattern_elem - _min, _max - pattern_elem) else: if pattern_elem == 0: maxdist = np.nan else: maxdist = max(pattern_elem - _min, _max - pattern_elem) / pattern_elem else: if error_type == ErrorType.Absolute: maxdist = max(pattern_elem - _min, _max - pattern_elem) ** 2 else: if pattern_elem == 0: maxdist = np.nan else: maxdist = (max(pattern_elem - _min, _max - pattern_elem) / pattern_elem) ** 2 if dist > 0 and maxdist == 0: return np.nan if not np.isnan(_max_error) and dist > _max_error: if max_error_strict: return np.nan dist = maxdist if dist > maxdist: dist = maxdist distance_sum = distance_sum + dist maxdistance_sum = maxdistance_sum + maxdist if not np.isnan(min_similarity): if not max_error_as_maxdist or max_error_.size == 1: if max_error_as_maxdist: if np.isnan(_max_error): return np.nan worst_maxdist = _max_error else: if distance_measure == DistanceMeasure.MAE: if error_type == ErrorType.Absolute: worst_maxdist = _max - _min else: if _min == 0: worst_maxdist = np.nan else: worst_maxdist = (_max - _min) / _min else: if error_type == ErrorType.Absolute: worst_maxdist = (_max - _min) ** 2 else: if _min == 0: worst_maxdist = np.nan else: worst_maxdist = ((_max - _min) / _min) ** 2 worst_maxdistance_sum = maxdistance_sum + worst_maxdist * (max_size - i - 1) if worst_maxdistance_sum == 0: return np.nan if distance_measure == DistanceMeasure.RMSE: if worst_maxdistance_sum == 0: best_similarity = np.nan else: best_similarity = 1 - np.sqrt(distance_sum) / np.sqrt(worst_maxdistance_sum) else: if worst_maxdistance_sum == 0: best_similarity = np.nan else: best_similarity = 1 - distance_sum / worst_maxdistance_sum if best_similarity < min_similarity: return np.nan if distance_sum == 0: return 1.0 if maxdistance_sum == 0: return np.nan if distance_measure == DistanceMeasure.RMSE: if maxdistance_sum == 0: similarity = np.nan else: similarity = 1 - 
# ############# Ranges ############# #


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), gap_value=None),
    merge_func=records_ch.merge_records,
    merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_ranges_nb(arr: tp.Array2d, gap_value: tp.Scalar) -> tp.RecordArray:
    """Fill range records between gaps.

    A range starts at the first non-gap value and is closed at the next gap value
    (`end_idx` points at the gap). A range that is still running at the last row
    is stored as open with `end_idx` set to the last row. An element counts as
    a gap if it equals `gap_value`, or if both are NaN.

    Usage:
        * Find ranges in time series:

        ```pycon
        >>> from vectorbtpro import *

        >>> a = np.array([
        ...     [np.nan, np.nan, np.nan, np.nan],
        ...     [     2, np.nan, np.nan, np.nan],
        ...     [     3,      3, np.nan, np.nan],
        ...     [np.nan,      4,      4, np.nan],
        ...     [     5, np.nan,      5,      5],
        ...     [     6,      6, np.nan,      6]
        ... ])
        >>> records = vbt.nb.get_ranges_nb(a, np.nan)

        >>> pd.DataFrame.from_records(records)
           id  col  start_idx  end_idx  status
        0   0    0          1        3       1
        1   1    0          4        5       0
        2   0    1          2        4       1
        3   1    1          5        5       0
        4   0    2          3        5       1
        5   0    3          4        5       0
        ```
    """
    n_rows = arr.shape[0]
    n_cols = arr.shape[1]
    records = np.empty(arr.shape, dtype=range_dt)
    counts = np.full(n_cols, 0, dtype=int_)

    for col in prange(n_cols):
        in_range = False
        range_start = -1

        for row in range(n_rows):
            value = arr[row, col]
            is_gap = value == gap_value or (np.isnan(value) and np.isnan(gap_value))

            emit = False
            range_end = -1
            range_status = -1

            if is_gap:
                if in_range:
                    # A gap terminates the running range
                    range_end = row
                    range_status = RangeStatus.Closed
                    emit = True
                    in_range = False
            elif not in_range:
                # First non-gap value opens a new range
                range_start = row
                in_range = True

            if in_range and row == n_rows - 1:
                # Data ended while the range was still running
                range_end = n_rows - 1
                range_status = RangeStatus.Open
                emit = True
                in_range = False

            if emit:
                slot = counts[col]
                records["id"][slot, col] = slot
                records["col"][slot, col] = col
                records["start_idx"][slot, col] = range_start
                records["end_idx"][slot, col] = range_end
                records["status"][slot, col] = range_status
                counts[col] += 1

    return repartition_nb(records, counts)
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")), ) @register_jitted(cache=True, tags={"can_parallel"}) def get_ranges_from_delta_nb( n_rows: int, idx_arr: tp.Array1d, id_arr: tp.Array1d, col_map: tp.GroupMap, index: tp.Optional[tp.Array1d] = None, delta: int = 0, delta_use_index: bool = False, shift: int = 0, ) -> tp.RecordArray: """Build delta ranges.""" col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens out = np.empty(idx_arr.shape[0], dtype=range_dt) for col in prange(col_lens.shape[0]): col_len = col_lens[col] if col_len == 0: continue col_start_idx = col_start_idxs[col] ridxs = col_idxs[col_start_idx : col_start_idx + col_len] for r in ridxs: r_idx = idx_arr[r] + shift if r_idx < 0: r_idx = 0 if r_idx > n_rows - 1: r_idx = n_rows - 1 if delta >= 0: start_idx = r_idx if delta_use_index: if index is None: raise ValueError("Index is required") end_idx = len(index) - 1 status = RangeStatus.Open for i in range(start_idx, index.shape[0]): if index[i] >= index[start_idx] + delta: end_idx = i status = RangeStatus.Closed break else: if start_idx + delta < n_rows: end_idx = start_idx + delta status = RangeStatus.Closed else: end_idx = n_rows - 1 status = RangeStatus.Open else: end_idx = r_idx status = RangeStatus.Closed if delta_use_index: if index is None: raise ValueError("Index is required") start_idx = 0 for i in range(end_idx, -1, -1): if index[i] <= index[end_idx] + delta: start_idx = i break else: if end_idx + delta >= 0: start_idx = end_idx + delta else: start_idx = 0 out["id"][r] = id_arr[r] out["col"][r] = col out["start_idx"][r] = start_idx out["end_idx"][r] = end_idx out["status"][r] = status return out @register_chunkable( size=ch.ArraySizer(arg_query="start_idx_arr", axis=0), arg_take_spec=dict( start_idx_arr=ch.ArraySlicer(axis=0), end_idx_arr=ch.ArraySlicer(axis=0), status_arr=ch.ArraySlicer(axis=0), freq=None, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def range_duration_nb( start_idx_arr: 
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        start_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        end_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        status_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        index_lens=ch.ArraySlicer(axis=0),
        overlapping=None,
        normalize=None,
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def range_coverage_nb(
    start_idx_arr: tp.Array1d,
    end_idx_arr: tp.Array1d,
    status_arr: tp.Array1d,
    col_map: tp.GroupMap,
    index_lens: tp.Array1d,
    overlapping: bool = False,
    normalize: bool = False,
) -> tp.Array1d:
    """Get coverage of range records.

    For each column, counts how many ranges cover each step: an open range covers
    `start_idx` to `end_idx` inclusive, any other range covers `start_idx` up to
    (but excluding) `end_idx`. The column length is taken from `index_lens`.

    Set `overlapping` to True to get the number of overlapping steps.
    Set `normalize` to True to get the number of steps in relation either to the total number
    of steps (when `overlapping=False`) or to the number of covered steps (when `overlapping=True`).

    Columns without records remain NaN, and so do normalized results whose
    denominator is zero.
    """
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full(col_lens.shape[0], np.nan, dtype=float_)

    for col in prange(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
        # Number of ranges covering each step of this column
        temp = np.full(index_lens[col], 0, dtype=int_)

        for r in ridxs:
            if status_arr[r] == RangeStatus.Open:
                # Open ranges include their last step
                temp[start_idx_arr[r] : end_idx_arr[r] + 1] += 1
            else:
                temp[start_idx_arr[r] : end_idx_arr[r]] += 1

        if overlapping:
            if normalize:
                pos_temp_sum = np.sum(temp > 0)
                if pos_temp_sum == 0:
                    out[col] = np.nan
                else:
                    out[col] = np.sum(temp > 1) / pos_temp_sum
            else:
                out[col] = np.sum(temp > 1)
        else:
            if normalize:
                if index_lens[col] == 0:
                    out[col] = np.nan
                else:
                    out[col] = np.sum(temp > 0) / index_lens[col]
            else:
                out[col] = np.sum(temp > 0)
    return out


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        start_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        end_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        status_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        index_len=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ranges_to_mask_nb(
    start_idx_arr: tp.Array1d,
    end_idx_arr: tp.Array1d,
    status_arr: tp.Array1d,
    col_map: tp.GroupMap,
    index_len: int,
) -> tp.Array2d:
    """Convert ranges to 2-dim mask.

    Returns a boolean array of shape `(index_len, number of columns)` that is True
    wherever at least one range of the column covers the row. An open range covers
    `start_idx` to `end_idx` inclusive, any other range excludes `end_idx`."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full((index_len, col_lens.shape[0]), False, dtype=np.bool_)

    for col in prange(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        ridxs = col_idxs[col_start_idx : col_start_idx + col_len]

        for r in ridxs:
            if status_arr[r] == RangeStatus.Open:
                # Open ranges include their last step
                out[start_idx_arr[r] : end_idx_arr[r] + 1, col] = True
            else:
                out[start_idx_arr[r] : end_idx_arr[r], col] = True

    return out
@register_jitted(cache=True) def map_ranges_to_projections_nb( close: tp.Array2d, col_arr: tp.Array1d, start_idx_arr: tp.Array1d, end_idx_arr: tp.Array1d, status_arr: tp.Array1d, index: tp.Optional[tp.Array1d] = None, proj_start: int = 0, proj_start_use_index: bool = False, proj_period: tp.Optional[int] = None, proj_period_use_index: bool = False, incl_end_idx: bool = True, extend: bool = False, rebase: bool = True, start_value: tp.FlexArray1dLike = 1.0, ffill: bool = False, remove_empty: bool = False, ) -> tp.Tuple[tp.Array1d, tp.Array2d]: """Map each range into a projection. Returns two arrays: 1. One-dimensional array where elements are record indices 2. Two-dimensional array where rows are projections""" start_value_ = to_1d_array_nb(np.asarray(start_value)) index_ranges_temp = np.empty((start_idx_arr.shape[0], 2), dtype=int_) max_duration = 0 for r in range(start_idx_arr.shape[0]): if proj_start_use_index: if index is None: raise ValueError("Index is required") r_proj_start = len(index) - start_idx_arr[r] for i in range(start_idx_arr[r], index.shape[0]): if index[i] >= index[start_idx_arr[r]] + proj_start: r_proj_start = i - start_idx_arr[r] break r_start_idx = start_idx_arr[r] + r_proj_start else: r_start_idx = start_idx_arr[r] + proj_start if status_arr[r] == RangeStatus.Open: if incl_end_idx: r_duration = end_idx_arr[r] - start_idx_arr[r] + 1 else: r_duration = end_idx_arr[r] - start_idx_arr[r] else: if incl_end_idx: r_duration = end_idx_arr[r] - start_idx_arr[r] else: r_duration = end_idx_arr[r] - start_idx_arr[r] - 1 if proj_period is None: r_end_idx = start_idx_arr[r] + r_duration else: if proj_period_use_index: if index is None: raise ValueError("Index is required") r_proj_period = -1 for i in range(r_start_idx, index.shape[0]): if index[i] <= index[r_start_idx] + proj_period: r_proj_period = i - r_start_idx else: break else: r_proj_period = proj_period if extend: r_end_idx = r_start_idx + r_proj_period else: r_end_idx = min(start_idx_arr[r] + 
@register_jitted(cache=True)
def find_pattern_1d_nb(
    arr: tp.Array1d,
    pattern: tp.Array1d,
    window: tp.Optional[int] = None,
    max_window: tp.Optional[int] = None,
    row_select_prob: float = 1.0,
    window_select_prob: float = 1.0,
    roll_forward: bool = False,
    interp_mode: int = InterpMode.Mixed,
    rescale_mode: int = RescaleMode.MinMax,
    vmin: float = np.nan,
    vmax: float = np.nan,
    pmin: float = np.nan,
    pmax: float = np.nan,
    invert: bool = False,
    error_type: int = ErrorType.Absolute,
    distance_measure: int = DistanceMeasure.MAE,
    max_error: tp.FlexArray1dLike = np.nan,
    max_error_interp_mode: tp.Optional[int] = None,
    max_error_as_maxdist: bool = False,
    max_error_strict: bool = False,
    min_pct_change: float = np.nan,
    max_pct_change: float = np.nan,
    min_similarity: float = 0.85,
    minp: tp.Optional[int] = None,
    overlap_mode: int = OverlapMode.Disallow,
    max_records: tp.Optional[int] = None,
    col: int = 0,
) -> tp.RecordArray:
    """Find all occurrences of a pattern in an array.

    Uses `vectorbtpro.generic.nb.patterns.pattern_similarity_nb` to fill records of the type
    `vectorbtpro.generic.enums.pattern_range_dt`.

    Goes through the array, and for each window selected between `window` and `max_window` (including),
    checks whether the window of array values is similar enough to the pattern. If so, writes a new
    range to the output array. If `window_select_prob` is set, decides whether to test a window based on
    the given probability. The same for `row_select_prob` but on rows.

    If `roll_forward`, windows are rolled forward (`start_idx` is guaranteed to be sorted), otherwise
    backward (`end_idx` is guaranteed to be sorted).

    If `window` is None, it defaults to the length of `pattern`. If `max_window` is None,
    it defaults to `window`.

    If `vmin` and/or `vmax` are provided (not NaN), they are passed to the similarity function
    unchanged for every window. Only bounds that are not provided are derived from the values
    of the current window, which matches how `pattern_similarity_nb` treats provided bounds.

    When a new match starts (or ends, if rolling backward) at the same row as the previous record,
    or overlaps it by more than `overlap_mode` elements, only the record with the higher similarity
    is kept. For `overlap_mode`, see `vectorbtpro.generic.enums.OverlapMode`.

    By default, creates an empty record array of the same size as the number of rows in `arr`.
    This can be increased or decreased using `max_records`.

    Raises:
        IndexError: If there are more matches than slots in the record array."""
    max_error_ = to_1d_array_nb(np.asarray(max_error))

    if window is None:
        window = pattern.shape[0]
    if max_window is None:
        max_window = window
    if max_records is None:
        records_out = np.empty(arr.shape[0], dtype=pattern_range_dt)
    else:
        records_out = np.empty(max_records, dtype=pattern_range_dt)
    r = 0

    # Pre-compute bounds only if the similarity function is going to need them
    min_max_required = False
    if rescale_mode == RescaleMode.MinMax:
        min_max_required = True
    if not np.isnan(min_pct_change):
        min_max_required = True
    if not np.isnan(max_pct_change):
        min_max_required = True
    if not max_error_as_maxdist:
        min_max_required = True
    if min_max_required:
        if np.isnan(pmin):
            pmin = np.nanmin(pattern)
        if np.isnan(pmax):
            pmax = np.nanmax(pattern)

    # Bounds provided by the caller must never be replaced by window values
    vmin_set = not np.isnan(vmin)
    vmax_set = not np.isnan(vmax)

    for i in range(arr.shape[0]):
        if roll_forward:
            from_i = i
            to_i = i + window
            if to_i > arr.shape[0]:
                break
        else:
            from_i = i - window + 1
            to_i = i + 1
            if from_i < 0:
                continue

        if np.random.uniform(0, 1) < row_select_prob:
            _vmin = vmin
            _vmax = vmax
            if min_max_required:
                if not vmin_set or not vmax_set:
                    # Bounds of the smallest window; larger windows extend them incrementally
                    for j in range(from_i, to_i):
                        if not vmin_set and (np.isnan(_vmin) or arr[j] < _vmin):
                            _vmin = arr[j]
                        if not vmax_set and (np.isnan(_vmax) or arr[j] > _vmax):
                            _vmax = arr[j]

            for w in range(window, max_window + 1):
                if roll_forward:
                    from_i = i
                    to_i = i + w
                    if to_i > arr.shape[0]:
                        break
                    if min_max_required:
                        if w > window:
                            # The window grew by one element on the right
                            if not vmin_set and arr[to_i - 1] < _vmin:
                                _vmin = arr[to_i - 1]
                            if not vmax_set and arr[to_i - 1] > _vmax:
                                _vmax = arr[to_i - 1]
                else:
                    from_i = i - w + 1
                    to_i = i + 1
                    if from_i < 0:
                        continue
                    if min_max_required:
                        if w > window:
                            # The window grew by one element on the left
                            if not vmin_set and arr[from_i] < _vmin:
                                _vmin = arr[from_i]
                            if not vmax_set and arr[from_i] > _vmax:
                                _vmax = arr[from_i]

                if np.random.uniform(0, 1) < window_select_prob:
                    arr_window = arr[from_i:to_i]
                    similarity = pattern_similarity_nb(
                        arr_window,
                        pattern,
                        interp_mode=interp_mode,
                        rescale_mode=rescale_mode,
                        vmin=_vmin,
                        vmax=_vmax,
                        pmin=pmin,
                        pmax=pmax,
                        invert=invert,
                        error_type=error_type,
                        distance_measure=distance_measure,
                        max_error=max_error_,
                        max_error_interp_mode=max_error_interp_mode,
                        max_error_as_maxdist=max_error_as_maxdist,
                        max_error_strict=max_error_strict,
                        min_pct_change=min_pct_change,
                        max_pct_change=max_pct_change,
                        min_similarity=min_similarity,
                        minp=minp,
                    )
                    if not np.isnan(similarity):
                        skip = False
                        # Drop previous records that conflict with this match and are less similar;
                        # skip this match if a conflicting previous record is at least as similar
                        while True:
                            if r > 0:
                                if roll_forward:
                                    prev_same_row = records_out["start_idx"][r - 1] == from_i
                                else:
                                    prev_same_row = records_out["end_idx"][r - 1] == to_i
                                if overlap_mode != OverlapMode.AllowAll and prev_same_row:
                                    if similarity > records_out["similarity"][r - 1]:
                                        r -= 1
                                        continue
                                    else:
                                        skip = True
                                        break
                                elif overlap_mode >= 0:
                                    overlap = records_out["end_idx"][r - 1] - from_i
                                    if overlap > overlap_mode:
                                        if similarity > records_out["similarity"][r - 1]:
                                            r -= 1
                                            continue
                                        else:
                                            skip = True
                                            break
                            break
                        if skip:
                            continue

                        if r >= records_out.shape[0]:
                            raise IndexError("Records index out of range. Set a higher max_records.")
                        records_out["id"][r] = r
                        records_out["col"][r] = col
                        records_out["start_idx"][r] = from_i
                        if to_i <= arr.shape[0] - 1:
                            records_out["end_idx"][r] = to_i
                            records_out["status"][r] = RangeStatus.Closed
                        else:
                            # The window reaches the end of the data
                            records_out["end_idx"][r] = arr.shape[0] - 1
                            records_out["status"][r] = RangeStatus.Open
                        records_out["similarity"][r] = similarity
                        r += 1

    return records_out[:r]
None, row_select_prob: float = 1.0, window_select_prob: float = 1.0, roll_forward: bool = False, interp_mode: int = InterpMode.Mixed, rescale_mode: int = RescaleMode.MinMax, vmin: float = np.nan, vmax: float = np.nan, pmin: float = np.nan, pmax: float = np.nan, invert: bool = False, error_type: int = ErrorType.Absolute, distance_measure: int = DistanceMeasure.MAE, max_error: tp.FlexArray1dLike = np.nan, max_error_interp_mode: tp.Optional[int] = None, max_error_as_maxdist: bool = False, max_error_strict: bool = False, min_pct_change: float = np.nan, max_pct_change: float = np.nan, min_similarity: float = 0.85, minp: tp.Optional[int] = None, overlap_mode: int = OverlapMode.Disallow, max_records: tp.Optional[int] = None, ) -> tp.RecordArray: """2-dim version of `find_pattern_1d_nb`.""" max_error_ = to_1d_array_nb(np.asarray(max_error)) if window is None: window = pattern.shape[0] if max_window is None: max_window = window if max_records is None: records_out = np.empty((arr.shape[0], arr.shape[1]), dtype=pattern_range_dt) else: records_out = np.empty((max_records, arr.shape[1]), dtype=pattern_range_dt) record_counts = np.full(arr.shape[1], 0, dtype=int_) for col in prange(arr.shape[1]): records = find_pattern_1d_nb( arr[:, col], pattern, window=window, max_window=max_window, row_select_prob=row_select_prob, window_select_prob=window_select_prob, roll_forward=roll_forward, interp_mode=interp_mode, rescale_mode=rescale_mode, vmin=vmin, vmax=vmax, pmin=pmin, pmax=pmax, invert=invert, error_type=error_type, distance_measure=distance_measure, max_error=max_error_, max_error_interp_mode=max_error_interp_mode, max_error_as_maxdist=max_error_as_maxdist, max_error_strict=max_error_strict, min_pct_change=min_pct_change, max_pct_change=max_pct_change, min_similarity=min_similarity, minp=minp, overlap_mode=overlap_mode, max_records=max_records, col=col, ) record_counts[col] = records.shape[0] records_out[: records.shape[0], col] = records return repartition_nb(records_out, 
# ############# Drawdowns ############# #


@register_jitted(cache=True)
def drawdown_1d_nb(arr: tp.Array1d) -> tp.Array1d:
    """Compute drawdown.

    The drawdown at each position is the relative distance of the value from
    the running maximum (`value / maximum - 1`). The result is NaN as long as
    the running maximum is NaN, and wherever it is zero."""
    dd = np.empty_like(arr, dtype=float_)
    peak = np.nan
    for i in range(arr.shape[0]):
        value = arr[i]
        if np.isnan(peak) or value > peak:
            peak = value
        if not np.isnan(peak) and peak != 0:
            dd[i] = value / peak - 1
        else:
            dd[i] = np.nan
    return dd


@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def drawdown_nb(arr: tp.Array2d) -> tp.Array2d:
    """2-dim version of `drawdown_1d_nb`.

    Applies `drawdown_1d_nb` to each column independently."""
    n_cols = arr.shape[1]
    dd = np.empty_like(arr, dtype=float_)
    for col in prange(n_cols):
        dd[:, col] = drawdown_1d_nb(arr[:, col])
    return dd


@register_jitted(cache=True)
def fill_drawdown_record_nb(
    new_records: tp.RecordArray2d,
    counts: tp.Array2d,
    i: int,
    col: int,
    start_idx: int,
    valley_idx: int,
    start_val: float,
    valley_val: float,
    end_val: float,
    status: int,
):
    """Fill a drawdown record.

    Writes the record into the next free slot of column `col` in `new_records`,
    using `i` as the end index, and increments `counts[col]` in place."""
    slot = counts[col]

    # Identity of the record
    new_records["id"][slot, col] = slot
    new_records["col"][slot, col] = col

    # Positions
    new_records["start_idx"][slot, col] = start_idx
    new_records["valley_idx"][slot, col] = valley_idx
    new_records["end_idx"][slot, col] = i

    # Values
    new_records["start_val"][slot, col] = start_val
    new_records["valley_val"][slot, col] = valley_val
    new_records["end_val"][slot, col] = end_val

    new_records["status"][slot, col] = status
    counts[col] += 1
tp.Optional[tp.Array2d], close: tp.Array2d, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.RecordArray: """Fill drawdown records by analyzing a time series. Only `close` must be provided, other time series are optional. Usage: ```pycon >>> from vectorbtpro import * >>> close = np.array([ ... [1, 5, 1, 3], ... [2, 4, 2, 2], ... [3, 3, 3, 1], ... [4, 2, 2, 2], ... [5, 1, 1, 3] ... ]) >>> records = vbt.nb.get_drawdowns_nb(None, None, None, close) >>> pd.DataFrame.from_records(records) id col start_idx valley_idx end_idx start_val valley_val end_val \\ 0 0 1 0 4 4 5.0 1.0 1.0 1 0 2 2 4 4 3.0 1.0 1.0 2 0 3 0 2 4 3.0 1.0 3.0 status 0 0 1 0 2 1 ``` """ new_records = np.empty(close.shape, dtype=drawdown_dt) counts = np.full(close.shape[1], 0, dtype=int_) sim_start_, sim_end_ = prepare_sim_range_nb( sim_shape=close.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(close.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue drawdown_started = False _close = close[0, col] if open is None: _open = np.nan else: _open = open[0, col] start_idx = 0 valley_idx = 0 start_val = _open valley_val = _open for i in range(_sim_start, _sim_end): _close = close[i, col] if open is None: _open = np.nan else: _open = open[i, col] if high is None: _high = np.nan else: _high = high[i, col] if low is None: _low = np.nan else: _low = low[i, col] if np.isnan(_high): if np.isnan(_open): _high = _close elif np.isnan(_close): _high = _open else: _high = max(_open, _close) if np.isnan(_low): if np.isnan(_open): _low = _close elif np.isnan(_close): _low = _open else: _low = min(_open, _close) if drawdown_started: if _open >= start_val: drawdown_started = False fill_drawdown_record_nb( new_records=new_records, counts=counts, i=i, col=col, start_idx=start_idx, valley_idx=valley_idx, start_val=start_val, valley_val=valley_val, end_val=_open, status=DrawdownStatus.Recovered, ) start_idx 
= i valley_idx = i start_val = _open valley_val = _open if drawdown_started: if _low < valley_val: valley_idx = i valley_val = _low if _high >= start_val: drawdown_started = False fill_drawdown_record_nb( new_records=new_records, counts=counts, i=i, col=col, start_idx=start_idx, valley_idx=valley_idx, start_val=start_val, valley_val=valley_val, end_val=_high, status=DrawdownStatus.Recovered, ) start_idx = i valley_idx = i start_val = _high valley_val = _high else: if np.isnan(start_val) or _high >= start_val: start_idx = i valley_idx = i start_val = _high valley_val = _high elif _low < valley_val: if not np.isnan(valley_val): drawdown_started = True valley_idx = i valley_val = _low if drawdown_started: if i == _sim_end - 1: drawdown_started = False fill_drawdown_record_nb( new_records=new_records, counts=counts, i=i, col=col, start_idx=start_idx, valley_idx=valley_idx, start_val=start_val, valley_val=valley_val, end_val=_close, status=DrawdownStatus.Active, ) return repartition_nb(new_records, counts) @register_chunkable( size=ch.ArraySizer(arg_query="start_val_arr", axis=0), arg_take_spec=dict(start_val_arr=ch.ArraySlicer(axis=0), valley_val_arr=ch.ArraySlicer(axis=0)), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def dd_drawdown_nb(start_val_arr: tp.Array1d, valley_val_arr: tp.Array1d) -> tp.Array1d: """Compute the drawdown of each drawdown record.""" out = np.empty(valley_val_arr.shape[0], dtype=float_) for r in prange(valley_val_arr.shape[0]): if start_val_arr[r] == 0: out[r] = np.nan else: out[r] = (valley_val_arr[r] - start_val_arr[r]) / start_val_arr[r] return out @register_chunkable( size=ch.ArraySizer(arg_query="start_idx_arr", axis=0), arg_take_spec=dict(start_idx_arr=ch.ArraySlicer(axis=0), valley_idx_arr=ch.ArraySlicer(axis=0)), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def dd_decline_duration_nb(start_idx_arr: tp.Array1d, valley_idx_arr: tp.Array1d) -> tp.Array1d: """Compute the duration of 
@register_chunkable(
    size=ch.ArraySizer(arg_query="start_idx_arr", axis=0),
    arg_take_spec=dict(
        start_idx_arr=ch.ArraySlicer(axis=0),
        valley_idx_arr=ch.ArraySlicer(axis=0),
        end_idx_arr=ch.ArraySlicer(axis=0),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_recovery_duration_ratio_nb(
    start_idx_arr: tp.Array1d,
    valley_idx_arr: tp.Array1d,
    end_idx_arr: tp.Array1d,
) -> tp.Array1d:
    """Compute the ratio of the recovery duration to the decline duration of each drawdown record.

    The decline duration is `valley_idx - start_idx` and the recovery duration is
    `end_idx - valley_idx`. Records with a zero decline duration yield NaN."""
    n_records = start_idx_arr.shape[0]
    ratio = np.empty(n_records, dtype=float_)
    for k in prange(n_records):
        decline_len = valley_idx_arr[k] - start_idx_arr[k]
        if decline_len != 0:
            ratio[k] = (end_idx_arr[k] - valley_idx_arr[k]) / decline_len
        else:
            # No decline phase: the ratio is undefined
            ratio[k] = np.nan
    return ratio
@register_jitted(cache=True)
def bar_price_nb(records: tp.RecordArray, price: tp.Optional[tp.FlexArray2d]) -> tp.Array1d:
    """Return the bar price of each record.

    For every record, the price is selected flexibly from `price` at the record's
    `idx` and `col` fields. If `price` is None, every element of the output is NaN."""
    n_records = len(records)
    bar_price = np.empty(n_records, dtype=float_)
    if price is None:
        for r in range(n_records):
            bar_price[r] = np.nan
    else:
        for r in range(n_records):
            rec = records[r]
            bar_price[r] = float(flex_select_nb(price, rec["idx"], rec["col"]))
    return bar_price
@register_jitted(cache=True)
def rolling_sum_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
    """Compute rolling sum.

    Uses `rolling_sum_acc_nb` at each iteration.

    Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).sum()`.

    If `minp` is None, it defaults to `window`. Raises `ValueError` if `minp` exceeds `window`."""
    if minp is None:
        minp = window
    if minp > window:
        raise ValueError("minp must be <= window")

    n = arr.shape[0]
    result = np.empty_like(arr, dtype=float_)
    running_sum = 0.0
    running_nancnt = 0
    for k in range(n):
        # Element that drops out of the window at this step (NaN while the window is filling)
        leaving = arr[k - window] if k >= window else np.nan
        state = rolling_sum_acc_nb(
            RollSumAIS(
                i=k,
                value=arr[k],
                pre_window_value=leaving,
                cumsum=running_sum,
                nancnt=running_nancnt,
                window=window,
                minp=minp,
            )
        )
        running_sum = state.cumsum
        running_nancnt = state.nancnt
        result[k] = state.value
    return result
@register_jitted(cache=True)
def rolling_prod_acc_nb(in_state: RollProdAIS) -> RollProdAOS:
    """Accumulator of `rolling_prod_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.RollProdAIS` and returns
    a state of type `vectorbtpro.generic.enums.RollProdAOS`.

    The incoming value is multiplied into the running product (NaN values are only counted),
    and once the window is full the value leaving the window is divided out again.
    The output value is NaN if fewer than `minp` non-NaN values are in the window.

    !!! note
        The leaving value is removed by division, so a zero leaving the window
        results in a division by zero."""
    running_prod = in_state.cumprod
    n_nan = in_state.nancnt
    incoming = in_state.value

    if np.isnan(incoming):
        n_nan += 1
    else:
        running_prod *= incoming

    if in_state.i >= in_state.window:
        # Window is full: drop the element that has just left it
        outgoing = in_state.pre_window_value
        if np.isnan(outgoing):
            n_nan -= 1
        else:
            running_prod /= outgoing
        n_valid = in_state.window - n_nan
    else:
        # Window is still filling up
        n_valid = in_state.i + 1 - n_nan

    if n_valid >= in_state.minp:
        out_value = running_prod
    else:
        out_value = np.nan
    return RollProdAOS(cumprod=running_prod, nancnt=n_nan, window_len=n_valid, value=out_value)
@register_jitted(cache=True)
def rolling_mean_acc_nb(in_state: RollMeanAIS) -> RollMeanAOS:
    """Accumulator of `rolling_mean_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.RollMeanAIS` and returns
    a state of type `vectorbtpro.generic.enums.RollMeanAOS`.

    The incoming value is added to the running sum (NaN values are only counted),
    and once the window is full the value leaving the window is subtracted again.
    The output value is the running sum divided by the number of non-NaN values
    in the window, or NaN if that number is below `minp` or zero."""
    i = in_state.i
    value = in_state.value
    pre_window_value = in_state.pre_window_value
    cumsum = in_state.cumsum
    nancnt = in_state.nancnt
    window = in_state.window
    minp = in_state.minp

    if np.isnan(value):
        nancnt = nancnt + 1
    else:
        cumsum = cumsum + value
    if i < window:
        # Window is still filling up
        window_len = i + 1 - nancnt
        window_cumsum = cumsum
    else:
        # Window is full: drop the element that has just left it
        if np.isnan(pre_window_value):
            nancnt = nancnt - 1
        else:
            cumsum = cumsum - pre_window_value
        window_len = window - nancnt
        window_cumsum = cumsum
    # An empty window (possible with minp=0) has no mean; avoid dividing by zero
    if window_len < minp or window_len == 0:
        value = np.nan
    else:
        value = window_cumsum / window_len

    return RollMeanAOS(cumsum=cumsum, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_std_acc_nb(in_state: RollStdAIS) -> RollStdAOS:
    """Accumulator of `rolling_std_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.RollStdAIS` and returns
    a state of type `vectorbtpro.generic.enums.RollStdAOS`.

    Maintains the running sum and the running sum of squares of the non-NaN values
    in the window. The output value is NaN if the number of non-NaN values in the window
    is below `minp`, is zero, or does not exceed `ddof`."""
    i = in_state.i
    value = in_state.value
    pre_window_value = in_state.pre_window_value
    cumsum = in_state.cumsum
    cumsum_sq = in_state.cumsum_sq
    nancnt = in_state.nancnt
    window = in_state.window
    minp = in_state.minp
    ddof = in_state.ddof

    if np.isnan(value):
        nancnt = nancnt + 1
    else:
        cumsum = cumsum + value
        cumsum_sq = cumsum_sq + value**2
    if i < window:
        # Window is still filling up
        window_len = i + 1 - nancnt
    else:
        # Window is full: drop the element that has just left it
        if np.isnan(pre_window_value):
            nancnt = nancnt - 1
        else:
            cumsum = cumsum - pre_window_value
            cumsum_sq = cumsum_sq - pre_window_value**2
        window_len = window - nancnt
    # Undefined when there are no observations (division by zero in the mean)
    # or when the degrees of freedom `window_len - ddof` are not positive
    if window_len < minp or window_len == 0 or window_len <= ddof:
        value = np.nan
    else:
        mean = cumsum / window_len
        # Sum of squared deviations expanded as sum(x^2) - 2*sum(x)*mean + n*mean^2;
        # the absolute value keeps the argument of the square root non-negative
        value = np.sqrt(np.abs(cumsum_sq - 2 * cumsum * mean + window_len * mean**2) / (window_len - ddof))

    return RollStdAOS(cumsum=cumsum, cumsum_sq=cumsum_sq, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_zscore_acc_nb(in_state: RollZScoreAIS) -> RollZScoreAOS:
    """Accumulator of `rolling_zscore_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.RollZScoreAIS` and returns
    a state of type `vectorbtpro.generic.enums.RollZScoreAOS`.

    Delegates to `rolling_mean_acc_nb` and `rolling_std_acc_nb` on the same input and
    returns `(value - mean) / std`, or NaN if the standard deviation is zero.
    The running statistics of the output state are taken from the standard deviation state."""
    mean_state = rolling_mean_acc_nb(
        RollMeanAIS(
            i=in_state.i,
            value=in_state.value,
            pre_window_value=in_state.pre_window_value,
            cumsum=in_state.cumsum,
            nancnt=in_state.nancnt,
            window=in_state.window,
            minp=in_state.minp,
        )
    )
    std_state = rolling_std_acc_nb(
        RollStdAIS(
            i=in_state.i,
            value=in_state.value,
            pre_window_value=in_state.pre_window_value,
            cumsum=in_state.cumsum,
            cumsum_sq=in_state.cumsum_sq,
            nancnt=in_state.nancnt,
            window=in_state.window,
            minp=in_state.minp,
            ddof=in_state.ddof,
        )
    )

    spread = std_state.value
    if spread != 0:
        zscore = (in_state.value - mean_state.value) / spread
    else:
        # Constant window: the z-score is undefined
        zscore = np.nan

    return RollZScoreAOS(
        cumsum=std_state.cumsum,
        cumsum_sq=std_state.cumsum_sq,
        nancnt=std_state.nancnt,
        window_len=std_state.window_len,
        value=zscore,
    )
@register_jitted(cache=True)
def wm_mean_acc_nb(in_state: WMMeanAIS) -> WMMeanAOS:
    """Accumulator of `wm_mean_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.WMMeanAIS` and returns
    a state of type `vectorbtpro.generic.enums.WMMeanAOS`.

    Maintains the plain running sum (`cumsum`) and the weighted running sum (`wcumsum`)
    of the non-NaN values in the window. The output value is the weighted sum divided
    by `window_len * (window_len + 1) / 2`, or NaN if the number of non-NaN values
    in the window is below `minp` or zero."""
    i = in_state.i
    value = in_state.value
    pre_window_value = in_state.pre_window_value
    cumsum = in_state.cumsum
    wcumsum = in_state.wcumsum
    nancnt = in_state.nancnt
    window = in_state.window
    minp = in_state.minp

    # Subtracting the previous plain sum lowers the weight of every value by one
    # once a non-NaN value leaves the window
    if i >= window and not np.isnan(pre_window_value):
        wcumsum = wcumsum - cumsum
    if np.isnan(value):
        nancnt = nancnt + 1
    else:
        cumsum = cumsum + value
    if i < window:
        # Window is still filling up
        window_len = i + 1 - nancnt
    else:
        # Window is full: drop the element that has just left it
        if np.isnan(pre_window_value):
            nancnt = nancnt - 1
        else:
            cumsum = cumsum - pre_window_value
        window_len = window - nancnt
    # The newest value enters with the highest weight
    if not np.isnan(value):
        wcumsum = wcumsum + value * window_len
    # An empty window (possible with minp=0) has no average; avoid dividing by zero
    if window_len < minp or window_len == 0:
        value = np.nan
    else:
        value = wcumsum * 2 / (window_len + 1) / window_len

    return WMMeanAOS(cumsum=cumsum, wcumsum=wcumsum, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def ewm_mean_acc_nb(in_state: EWMMeanAIS) -> EWMMeanAOS:
    """Accumulator of `ewm_mean_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.EWMMeanAIS` and returns
    a state of type `vectorbtpro.generic.enums.EWMMeanAOS`.

    For `i > 0`, the weight of the previous average is decayed by `1 - alpha` and, if the
    incoming value is not NaN, the average is updated as a weighted mean of the previous
    average and the value. For `i == 0`, the average is left untouched and only counted
    as an observation if it is not NaN (presumably the caller seeds `weighted_avg` with
    the first element). The output value is NaN as long as `nobs` is below `minp`."""
    i = in_state.i
    value = in_state.value
    old_wt = in_state.old_wt
    weighted_avg = in_state.weighted_avg
    nobs = in_state.nobs
    alpha = in_state.alpha
    minp = in_state.minp
    adjust = in_state.adjust

    # Decay applied to the weight of the previous average at each step
    old_wt_factor = 1.0 - alpha
    # Weight of the incoming value: 1 in the adjusted form, alpha in the recursive form
    new_wt = 1.0 if adjust else alpha

    if i > 0:
        is_observation = not np.isnan(value)
        nobs += is_observation
        if not np.isnan(weighted_avg):
            # The decay is applied even if the incoming value is NaN
            old_wt *= old_wt_factor
            if is_observation:
                # avoid numerical errors on constant series
                if weighted_avg != value:
                    weighted_avg = ((old_wt * weighted_avg) + (new_wt * value)) / (old_wt + new_wt)
                if adjust:
                    old_wt += new_wt
                else:
                    old_wt = 1.0
        elif is_observation:
            # No valid average yet: start from the first non-NaN value
            weighted_avg = value
    else:
        # First element: the average itself tells whether there is an observation
        is_observation = not np.isnan(weighted_avg)
        nobs += int(is_observation)

    value = weighted_avg if (nobs >= minp) else np.nan
    return EWMMeanAOS(old_wt=old_wt, weighted_avg=weighted_avg, nobs=nobs, value=value)
@register_jitted(cache=True)
def ewm_std_acc_nb(in_state: EWMStdAIS) -> EWMStdAOS:
    """Accumulator of `ewm_std_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.EWMStdAIS` and returns
    a state of type `vectorbtpro.generic.enums.EWMStdAOS`.

    Updates the exponentially weighted means and the covariance of the value with itself,
    together with the sum of weights (`sum_wt`) and the sum of squared weights (`sum_wt2`).

    !!! note
        The output value is `cov` scaled by `sum_wt**2 / (sum_wt**2 - sum_wt2)`, that is,
        a variance; no square root is taken here. It is NaN if `nobs` is below `minp`
        or the scaling denominator is not positive."""
    i = in_state.i
    value = in_state.value
    mean_x = in_state.mean_x
    mean_y = in_state.mean_y
    cov = in_state.cov
    sum_wt = in_state.sum_wt
    sum_wt2 = in_state.sum_wt2
    old_wt = in_state.old_wt
    nobs = in_state.nobs
    alpha = in_state.alpha
    minp = in_state.minp
    adjust = in_state.adjust

    # Decay applied to the old weights at each step
    old_wt_factor = 1.0 - alpha
    # Weight of the incoming value: 1 in the adjusted form, alpha in the recursive form
    new_wt = 1.0 if adjust else alpha

    # Both series are the same value: the covariance of a series with itself
    cur_x = value
    cur_y = value
    is_observation = not np.isnan(cur_x) and not np.isnan(cur_y)
    nobs += is_observation
    if i > 0:
        if not np.isnan(mean_x):
            # The decay is applied even if the incoming value is NaN
            sum_wt *= old_wt_factor
            sum_wt2 *= old_wt_factor * old_wt_factor
            old_wt *= old_wt_factor
            if is_observation:
                old_mean_x = mean_x
                old_mean_y = mean_y

                # avoid numerical errors on constant series
                if mean_x != cur_x:
                    mean_x = ((old_wt * old_mean_x) + (new_wt * cur_x)) / (old_wt + new_wt)

                # avoid numerical errors on constant series
                if mean_y != cur_y:
                    mean_y = ((old_wt * old_mean_y) + (new_wt * cur_y)) / (old_wt + new_wt)
                cov = (
                    (old_wt * (cov + ((old_mean_x - mean_x) * (old_mean_y - mean_y))))
                    + (new_wt * ((cur_x - mean_x) * (cur_y - mean_y)))
                ) / (old_wt + new_wt)
                sum_wt += new_wt
                sum_wt2 += new_wt * new_wt
                old_wt += new_wt
                if not adjust:
                    # Renormalize so that the old weight is one again
                    sum_wt /= old_wt
                    sum_wt2 /= old_wt * old_wt
                    old_wt = 1.0
        elif is_observation:
            # No valid means yet: start from the first non-NaN value
            mean_x = cur_x
            mean_y = cur_y
    else:
        # First element: invalidate the seeded means if it is not an observation
        if not is_observation:
            mean_x = np.nan
            mean_y = np.nan

    if nobs >= minp:
        # Bias correction factor of the weighted covariance
        numerator = sum_wt * sum_wt
        denominator = numerator - sum_wt2
        if denominator > 0.0:
            value = (numerator / denominator) * cov
        else:
            value = np.nan
    else:
        value = np.nan
    return EWMStdAOS(
        mean_x=mean_x,
        mean_y=mean_y,
        cov=cov,
        sum_wt=sum_wt,
        sum_wt2=sum_wt2,
        old_wt=old_wt,
        nobs=nobs,
        value=value,
    )
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    # Only the parameters of this function are listed: it takes no `ddof`
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), span=None, minp=None, adjust=None),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ewm_std_nb(arr: tp.Array2d, span: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array2d:
    """2-dim version of `ewm_std_1d_nb`.

    Applies `ewm_std_1d_nb` with the same `span`, `minp`, and `adjust` to each column
    of `arr` and returns an array of the same shape."""
    out = np.empty_like(arr, dtype=float_)
    for col in prange(arr.shape[1]):
        out[:, col] = ewm_std_1d_nb(arr[:, col], span, minp=minp, adjust=adjust)
    return out
@register_jitted(cache=True)
def vidya_acc_nb(in_state: VidyaAIS) -> VidyaAOS:
    """Accumulator of `vidya_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.VidyaAIS` and returns
    a state of type `vectorbtpro.generic.enums.VidyaAOS`.

    Maintains the running sums of positive and negative one-step changes in the window,
    derives `cmo` as `abs((pos - neg) / (pos + neg))` (zero if both sums are zero), and
    smooths the value with the factor `alpha * cmo`, where `alpha = 2 / (window + 1)`.
    A NaN previous VIDYA is treated as zero. Both `cmo` and `vidya` are NaN if the number
    of non-NaN changes in the window is below `minp`."""
    idx = in_state.i
    cur = in_state.value
    window = in_state.window
    up_sum = in_state.pos_cumsum
    down_sum = in_state.neg_cumsum
    n_nan = in_state.nancnt
    last_vidya = in_state.prev_vidya

    smoothing = 2 / (window + 1)

    # Change entering the window
    change = cur - in_state.prev_value
    if np.isnan(change):
        n_nan += 1
    else:
        gain = change if change > 0 else 0.0
        loss = 0.0 if change > 0 else abs(change)
        up_sum += gain
        down_sum += loss

    if idx >= window:
        # Window is full: remove the change that has just left it
        old_change = in_state.pre_window_value - in_state.pre_window_prev_value
        if np.isnan(old_change):
            n_nan -= 1
        else:
            old_gain = old_change if old_change > 0 else 0.0
            old_loss = 0.0 if old_change > 0 else abs(old_change)
            up_sum -= old_gain
            down_sum -= old_loss
        n_valid = window - n_nan
    else:
        # Window is still filling up
        n_valid = idx + 1 - n_nan

    if n_valid >= in_state.minp:
        total = up_sum + down_sum
        if total == 0:
            cmo = 0.0
        else:
            cmo = np.abs((up_sum - down_sum) / total)
        if np.isnan(last_vidya):
            last_vidya = 0.0
        vidya = smoothing * cmo * cur + last_vidya * (1 - smoothing * cmo)
    else:
        cmo = np.nan
        vidya = np.nan

    return VidyaAOS(
        pos_cumsum=up_sum,
        neg_cumsum=down_sum,
        nancnt=n_nan,
        window_len=n_valid,
        cmo=cmo,
        vidya=vidya,
    )
@register_jitted(cache=True)
def ma_1d_nb(
    arr: tp.Array1d,
    window: int,
    wtype: int = WType.Simple,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Dispatch to the moving-average kernel selected by `wtype`.

    `wtype` is a member of `vectorbtpro.generic.enums.WType`. `adjust` is forwarded
    only to the exponential and Wilder kernels; the other kernels receive just
    `window` and `minp`.

    Raises:
        ValueError: If `wtype` matches none of the supported modes."""
    if wtype == WType.Simple:
        result = rolling_mean_1d_nb(arr, window, minp=minp)
    elif wtype == WType.Weighted:
        result = wm_mean_1d_nb(arr, window, minp=minp)
    elif wtype == WType.Exp:
        result = ewm_mean_1d_nb(arr, window, minp=minp, adjust=adjust)
    elif wtype == WType.Wilder:
        result = wwm_mean_1d_nb(arr, window, minp=minp, adjust=adjust)
    elif wtype == WType.Vidya:
        result = vidya_1d_nb(arr, window, minp=minp)
    else:
        raise ValueError("Invalid rolling mode")
    return result
@register_jitted(cache=True)
def rolling_cov_acc_nb(in_state: RollCovAIS) -> RollCovAOS:
    """Accumulator of `rolling_cov_1d_nb`.

    Takes a state of type `vectorbtpro.generic.enums.RollCovAIS` and returns
    a state of type `vectorbtpro.generic.enums.RollCovAOS`.

    A pair of observations is treated as missing if either value is NaN: it is
    not added to the running sums and only increments the NaN counter. Once the
    window is full, the pair that has just left the window is removed in the same way.

    The covariance is NaN when fewer than `minp` valid pairs are in the window or
    when the number of valid pairs does not exceed `ddof` (non-positive degrees
    of freedom)."""
    i = in_state.i
    value1 = in_state.value1
    value2 = in_state.value2
    pre_window_value1 = in_state.pre_window_value1
    pre_window_value2 = in_state.pre_window_value2
    cumsum1 = in_state.cumsum1
    cumsum2 = in_state.cumsum2
    cumsum_prod = in_state.cumsum_prod
    nancnt = in_state.nancnt
    window = in_state.window
    minp = in_state.minp
    ddof = in_state.ddof

    # Add the newest pair
    if np.isnan(value1) or np.isnan(value2):
        nancnt = nancnt + 1
    else:
        cumsum1 = cumsum1 + value1
        cumsum2 = cumsum2 + value2
        cumsum_prod = cumsum_prod + value1 * value2
    if i < window:
        # Window is still filling up
        window_len = i + 1 - nancnt
    else:
        # Remove the pair that has just left the window
        if np.isnan(pre_window_value1) or np.isnan(pre_window_value2):
            nancnt = nancnt - 1
        else:
            cumsum1 = cumsum1 - pre_window_value1
            cumsum2 = cumsum2 - pre_window_value2
            cumsum_prod = cumsum_prod - pre_window_value1 * pre_window_value2
        window_len = window - nancnt
    # `<=` rather than `==`: with window_len < ddof the divisor below would be
    # negative and yield a sign-flipped, meaningless covariance instead of NaN
    if window_len < minp or window_len <= ddof:
        value = np.nan
    else:
        # cov = E[xy] * n / (n - ddof) - E[x] * E[y] * n / (n - ddof)
        window_prod_mean = cumsum_prod / (window_len - ddof)
        window_mean1 = cumsum1 / window_len
        window_mean2 = cumsum2 / window_len
        window_mean_prod = window_mean1 * window_mean2 * window_len / (window_len - ddof)
        value = window_prod_mean - window_mean_prod

    return RollCovAOS(
        cumsum1=cumsum1,
        cumsum2=cumsum2,
        cumsum_prod=cumsum_prod,
        nancnt=nancnt,
        window_len=window_len,
        value=value,
    )
@register_jitted(cache=True)
def rolling_corr_acc_nb(in_state: RollCorrAIS) -> RollCorrAOS:
    """Perform a single accumulation step of `rolling_corr_1d_nb`.

    Consumes a state of type `vectorbtpro.generic.enums.RollCorrAIS` and produces
    a state of type `vectorbtpro.generic.enums.RollCorrAOS`.

    Running sums of both series, their squares, and their product are updated with
    the newest pair and, once the window is full, reduced by the pair that has left it.
    A pair with a NaN in either value is only counted, never summed. The coefficient
    is NaN if fewer than `minp` valid pairs are available or if the denominator is zero."""
    i = in_state.i
    window = in_state.window
    minp = in_state.minp
    x = in_state.value1
    y = in_state.value2
    old_x = in_state.pre_window_value1
    old_y = in_state.pre_window_value2
    cumsum1 = in_state.cumsum1
    cumsum2 = in_state.cumsum2
    cumsum_sq1 = in_state.cumsum_sq1
    cumsum_sq2 = in_state.cumsum_sq2
    cumsum_prod = in_state.cumsum_prod
    nancnt = in_state.nancnt

    # Incoming pair
    if np.isnan(x) or np.isnan(y):
        nancnt += 1
    else:
        cumsum1 += x
        cumsum2 += y
        cumsum_sq1 += x**2
        cumsum_sq2 += y**2
        cumsum_prod += x * y

    if i >= window:
        # Outgoing pair
        if np.isnan(old_x) or np.isnan(old_y):
            nancnt -= 1
        else:
            cumsum1 -= old_x
            cumsum2 -= old_y
            cumsum_sq1 -= old_x**2
            cumsum_sq2 -= old_y**2
            cumsum_prod -= old_x * old_y
        window_len = window - nancnt
    else:
        window_len = i + 1 - nancnt

    if window_len >= minp:
        nom = window_len * cumsum_prod - cumsum1 * cumsum2
        root1 = np.sqrt(window_len * cumsum_sq1 - cumsum1**2)
        root2 = np.sqrt(window_len * cumsum_sq2 - cumsum2**2)
        denom = root1 * root2
        if denom == 0:
            value = np.nan
        else:
            value = nom / denom
    else:
        value = np.nan

    return RollCorrAOS(
        cumsum1=cumsum1,
        cumsum2=cumsum2,
        cumsum_sq1=cumsum_sq1,
        cumsum_sq2=cumsum_sq2,
        cumsum_prod=cumsum_prod,
        nancnt=nancnt,
        window_len=window_len,
        value=value,
    )
@register_jitted(cache=True)
def rolling_ols_acc_nb(in_state: RollOLSAIS) -> RollOLSAOS:
    """Perform a single accumulation step of `rolling_ols_1d_nb`.

    Consumes a state of type `vectorbtpro.generic.enums.RollOLSAIS` and produces
    a state of type `vectorbtpro.generic.enums.RollOLSAOS`.

    The first series acts as the regressor and the second as the response: the slope
    is `(n * Sxy - Sx * Sy) / (n * Sxx - Sx**2)` and the intercept is
    `(Sy - slope * Sx) / n`, where `n` is the number of valid pairs. A pair with a NaN
    in either value is only counted, never summed. Both outputs are NaN if fewer than
    `minp` valid pairs are in the window; each is also NaN if its own denominator is zero."""
    i = in_state.i
    window = in_state.window
    minp = in_state.minp
    x = in_state.value1
    y = in_state.value2
    old_x = in_state.pre_window_value1
    old_y = in_state.pre_window_value2
    validcnt = in_state.validcnt
    cumsum1 = in_state.cumsum1
    cumsum2 = in_state.cumsum2
    cumsum_sq1 = in_state.cumsum_sq1
    cumsum_prod = in_state.cumsum_prod
    nancnt = in_state.nancnt

    # Incoming pair
    if np.isnan(x) or np.isnan(y):
        nancnt += 1
    else:
        validcnt += 1
        cumsum1 += x
        cumsum2 += y
        cumsum_sq1 += x**2
        cumsum_prod += x * y

    if i >= window:
        # Outgoing pair
        if np.isnan(old_x) or np.isnan(old_y):
            nancnt -= 1
        else:
            validcnt -= 1
            cumsum1 -= old_x
            cumsum2 -= old_y
            cumsum_sq1 -= old_x**2
            cumsum_prod -= old_x * old_y
        window_len = window - nancnt
    else:
        window_len = i + 1 - nancnt

    if window_len >= minp:
        slope_num = validcnt * cumsum_prod - cumsum1 * cumsum2
        slope_denom = validcnt * cumsum_sq1 - cumsum1**2
        if slope_denom == 0:
            slope_value = np.nan
        else:
            slope_value = slope_num / slope_denom
        # A NaN slope propagates into the intercept through this product
        intercept_num = cumsum2 - slope_value * cumsum1
        if validcnt == 0:
            intercept_value = np.nan
        else:
            intercept_value = intercept_num / validcnt
    else:
        slope_value = np.nan
        intercept_value = np.nan

    return RollOLSAOS(
        validcnt=validcnt,
        cumsum1=cumsum1,
        cumsum2=cumsum2,
        cumsum_sq1=cumsum_sq1,
        cumsum_prod=cumsum_prod,
        nancnt=nancnt,
        window_len=window_len,
        slope_value=slope_value,
        intercept_value=intercept_value,
    )
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr1", axis=1),
    arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1), window=None, minp=None),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_ols_nb(
    arr1: tp.Array2d,
    arr2: tp.Array2d,
    window: int,
    minp: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
    """2-dim version of `rolling_ols_1d_nb`.

    Runs `rolling_ols_1d_nb` on each pair of columns with the same index in `arr1`
    and `arr2`.

    Returns:
        A tuple of two float arrays shaped like `arr1`: the slopes and the intercepts."""
    slope_out = np.empty_like(arr1, dtype=float_)
    intercept_out = np.empty_like(arr1, dtype=float_)
    for col in prange(arr1.shape[1]):
        slope_out[:, col], intercept_out[:, col] = rolling_ols_1d_nb(arr1[:, col], arr2[:, col], window, minp=minp)
    return slope_out, intercept_out
@register_jitted(cache=True)
def rolling_min_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
    """Compute the minimum over a trailing window.

    Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).min()`.

    NaN elements are skipped. A position with fewer than `minp` non-NaN elements
    in its window yields NaN. `minp` defaults to `window`.

    Raises:
        ValueError: If `minp` is greater than `window`."""
    if minp is None:
        minp = window
    if minp > window:
        raise ValueError("minp must be <= window")
    n = arr.shape[0]
    out = np.empty_like(arr, dtype=float_)
    for i in range(n):
        start = max(i - window + 1, 0)
        # Seed with the first window element; a NaN seed is replaced by the first valid one
        lowest = arr[start]
        n_valid = 0
        for j in range(start, i + 1):
            x = arr[j]
            if not np.isnan(x):
                if np.isnan(lowest) or x < lowest:
                    lowest = x
                n_valid += 1
        if n_valid >= minp:
            out[i] = lowest
        else:
            out[i] = np.nan
    return out
@register_jitted(cache=True)
def rolling_argmin_1d_nb(
    arr: tp.Array1d,
    window: int,
    minp: tp.Optional[int] = None,
    local: bool = False,
) -> tp.Array1d:
    """Compute the index of the minimum over a trailing window.

    With `local=False` the index refers to a position in `arr`; with `local=True`
    it is counted from the first element of the window. On ties the earliest
    position wins. NaN elements are skipped, and a position with fewer than `minp`
    non-NaN elements in its window yields -1. `minp` defaults to `window`.

    Raises:
        ValueError: If `minp` is greater than `window`."""
    if minp is None:
        minp = window
    if minp > window:
        raise ValueError("minp must be <= window")
    n = arr.shape[0]
    out = np.empty_like(arr, dtype=int_)
    for i in range(n):
        start = max(i - window + 1, 0)
        # Seed with the first window element; a NaN seed is replaced by the first valid one
        lowest = arr[start]
        lowest_j = start
        n_valid = 0
        for j in range(start, i + 1):
            x = arr[j]
            if not np.isnan(x):
                if np.isnan(lowest) or x < lowest:
                    lowest = x
                    lowest_j = j
                n_valid += 1
        if n_valid < minp:
            out[i] = -1
        elif local:
            # Translate the absolute position into an offset within the window
            out[i] = lowest_j - start
        else:
            out[i] = lowest_j
    return out
@register_jitted(cache=True)
def rolling_any_1d_nb(arr: tp.Array1d, window: int) -> tp.Array1d:
    """Tell for each position whether its trailing window contains a truthy element.

    NaN elements count as false. Only the position of the most recent truthy
    element is remembered; the result is True whenever that position lies
    inside the current window."""
    out = np.empty_like(arr, dtype=np.bool_)
    latest_hit = -1
    for i in range(arr.shape[0]):
        x = arr[i]
        if not np.isnan(x) and x:
            latest_hit = i
        window_start = max(0, i + 1 - window)
        out[i] = latest_hit >= window_start
    return out
@register_jitted(cache=True)
def rolling_all_1d_nb(arr: tp.Array1d, window: int) -> tp.Array1d:
    """Tell for each position whether every element of its trailing window is truthy.

    NaN elements count as false. Only the position of the most recent falsy
    element is remembered; the result is False whenever that position lies
    inside the current window."""
    out = np.empty_like(arr, dtype=np.bool_)
    latest_miss = -1
    for i in range(arr.shape[0]):
        x = arr[i]
        if np.isnan(x) or not x:
            latest_miss = i
        window_start = max(0, i + 1 - window)
        out[i] = latest_miss < window_start
    return out
Uses `vectorbtpro.generic.nb.patterns.pattern_similarity_nb`.""" max_error_ = to_1d_array_nb(np.asarray(max_error)) if window is None: window = pattern.shape[0] if max_window is None: max_window = window out = np.full(arr.shape, np.nan, dtype=float_) min_max_required = False if rescale_mode == RescaleMode.MinMax: min_max_required = True if not np.isnan(min_pct_change): min_max_required = True if not np.isnan(max_pct_change): min_max_required = True if not max_error_as_maxdist: min_max_required = True if min_max_required: if np.isnan(pmin): pmin = np.nanmin(pattern) if np.isnan(pmax): pmax = np.nanmax(pattern) for i in range(arr.shape[0]): from_i = i - window + 1 to_i = i + 1 if from_i < 0: continue if np.random.uniform(0, 1) < row_select_prob: _vmin = vmin _vmax = vmax if min_max_required: if np.isnan(_vmin) or np.isnan(_vmax): for j in range(from_i, to_i): if np.isnan(_vmin) or arr[j] < _vmin: _vmin = arr[j] if np.isnan(_vmax) or arr[j] > _vmax: _vmax = arr[j] for w in range(window, max_window + 1): from_i = i - w + 1 to_i = i + 1 if from_i < 0: continue if min_max_required: if w > window: if arr[from_i] < _vmin: _vmin = arr[from_i] if arr[from_i] > _vmax: _vmax = arr[from_i] if np.random.uniform(0, 1) < window_select_prob: arr_window = arr[from_i:to_i] similarity = pattern_similarity_nb( arr_window, pattern, interp_mode=interp_mode, rescale_mode=rescale_mode, vmin=_vmin, vmax=_vmax, pmin=pmin, pmax=pmax, invert=invert, error_type=error_type, distance_measure=distance_measure, max_error=max_error_, max_error_interp_mode=max_error_interp_mode, max_error_as_maxdist=max_error_as_maxdist, max_error_strict=max_error_strict, min_pct_change=min_pct_change, max_pct_change=max_pct_change, min_similarity=min_similarity, minp=minp, ) if not np.isnan(similarity): if not np.isnan(out[i]): if similarity > out[i]: out[i] = similarity else: out[i] = similarity return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), arg_take_spec=dict( 
arr=ch.ArraySlicer(axis=1), pattern=None, window=None, max_window=None, row_select_prob=None, window_select_prob=None, interp_mode=None, rescale_mode=None, vmin=None, vmax=None, pmin=None, pmax=None, invert=None, error_type=None, distance_measure=None, max_error=None, max_error_interp_mode=None, max_error_as_maxdist=None, max_error_strict=None, min_pct_change=None, max_pct_change=None, min_similarity=None, minp=None, ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def rolling_pattern_similarity_nb( arr: tp.Array2d, pattern: tp.Array1d, window: tp.Optional[int] = None, max_window: tp.Optional[int] = None, row_select_prob: float = 1.0, window_select_prob: float = 1.0, interp_mode: int = InterpMode.Mixed, rescale_mode: int = RescaleMode.MinMax, vmin: float = np.nan, vmax: float = np.nan, pmin: float = np.nan, pmax: float = np.nan, invert: bool = False, error_type: int = ErrorType.Absolute, distance_measure: int = DistanceMeasure.MAE, max_error: tp.FlexArray1dLike = np.nan, max_error_interp_mode: tp.Optional[int] = None, max_error_as_maxdist: bool = False, max_error_strict: bool = False, min_pct_change: float = np.nan, max_pct_change: float = np.nan, min_similarity: float = 0.85, minp: tp.Optional[int] = None, ) -> tp.Array2d: """2-dim version of `rolling_pattern_similarity_1d_nb`.""" max_error_ = to_1d_array_nb(np.asarray(max_error)) if window is None: window = pattern.shape[0] if max_window is None: max_window = window out = np.full(arr.shape, np.nan, dtype=float_) for col in prange(arr.shape[1]): out[:, col] = rolling_pattern_similarity_1d_nb( arr[:, col], pattern, window=window, max_window=max_window, row_select_prob=row_select_prob, window_select_prob=window_select_prob, interp_mode=interp_mode, rescale_mode=rescale_mode, vmin=vmin, vmax=vmax, pmin=pmin, pmax=pmax, invert=invert, error_type=error_type, distance_measure=distance_measure, max_error=max_error_, max_error_interp_mode=max_error_interp_mode, 
@register_jitted(cache=True)
def expanding_min_1d_nb(arr: tp.Array1d, minp: int = 1) -> tp.Array1d:
    """Compute expanding min.

    Numba equivalent to `pd.Series(arr).expanding(min_periods=minp).min()`.

    NaN elements do not count as observations and never replace a valid running
    minimum. Positions reached before `minp` non-NaN elements have been seen
    yield NaN. An empty input yields an empty output."""
    out = np.empty_like(arr, dtype=float_)
    # Seeding from `arr[0]` below is out of bounds for empty input
    if arr.shape[0] == 0:
        return out
    minv = arr[0]
    cnt = 0
    for i in range(arr.shape[0]):
        # A NaN running minimum is replaced by whatever comes next
        if np.isnan(minv) or arr[i] < minv:
            minv = arr[i]
        if not np.isnan(arr[i]):
            cnt += 1
        if cnt < minp:
            out[i] = np.nan
        else:
            out[i] = minv
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="arr", axis=1),
    arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), minp=None),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_max_nb(arr: tp.Array2d, minp: int = 1) -> tp.Array2d:
    """Apply `expanding_max_1d_nb` to every column of a 2-dim array."""
    n_cols = arr.shape[1]
    result = np.empty_like(arr, dtype=float_)
    for c in prange(n_cols):
        result[:, c] = expanding_max_1d_nb(arr[:, c], minp=minp)
    return result
@register_jitted(cache=True)
def resolve_sim_start_nb(
    sim_shape: tp.Shape,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    allow_none: bool = False,
    check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
    """Resolve simulation start.

    Args:
        sim_shape: Simulation shape as (number of rows, number of columns).
        sim_start: Start row(s), selected per column with `flex_select_1d_pc_nb`.
            None stands for the first row.
        allow_none: Whether to return None if every resolved start is 0.
        check_bounds: If False and `sim_start` already has one element per column,
            it is returned as an integer array without any bounds processing.

    Returns:
        An integer array with one start row per column, or None (see `allow_none`).

    Negative values are counted from the end, and every resolved value is limited
    to the range from 0 to the number of rows, both inclusive."""
    if sim_start is None:
        if allow_none:
            return None
        return np.full(sim_shape[1], 0, dtype=int_)

    sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
    if not check_bounds and len(sim_start_) == sim_shape[1]:
        return sim_start_

    sim_start_out = np.empty(sim_shape[1], dtype=int_)
    can_be_none = True

    for i in range(sim_shape[1]):
        _sim_start = flex_select_1d_pc_nb(sim_start_, i)
        if _sim_start < 0:
            _sim_start = sim_shape[0] + _sim_start
            # A value below -n_rows is still negative after wrapping around;
            # clamp it so that it cannot be used as a negative row index
            if _sim_start < 0:
                _sim_start = 0
        elif _sim_start > sim_shape[0]:
            _sim_start = sim_shape[0]
        sim_start_out[i] = _sim_start
        if _sim_start != 0:
            can_be_none = False

    if allow_none and can_be_none:
        return None
    return sim_start_out
@register_jitted(cache=True)
def resolve_sim_range_nb(
    sim_shape: tp.Shape,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
    allow_none: bool = False,
    check_bounds: bool = True,
) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]:
    """Resolve both ends of the simulation range.

    Returns the result of `resolve_sim_start_nb` followed by the result of
    `resolve_sim_end_nb`; `allow_none` and `check_bounds` are forwarded to both."""
    return (
        resolve_sim_start_nb(
            sim_shape=sim_shape,
            sim_start=sim_start,
            allow_none=allow_none,
            check_bounds=check_bounds,
        ),
        resolve_sim_end_nb(
            sim_shape=sim_shape,
            sim_end=sim_end,
            allow_none=allow_none,
            check_bounds=check_bounds,
        ),
    )
min_sim_start != 0: can_be_none = False if allow_none and can_be_none: return None return new_sim_start @register_jitted(cache=True) def resolve_grouped_sim_end_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_end: tp.Optional[tp.FlexArray1dLike] = None, allow_none: bool = False, check_bounds: bool = True, ) -> tp.Optional[tp.Array1d]: """Resolve grouped simulation end.""" if sim_end is None: if allow_none: return None return np.full(len(group_lens), target_shape[0], dtype=int_) sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_)) if len(sim_end_) == len(group_lens): if not check_bounds: return sim_end_ return resolve_sim_end_nb( (target_shape[0], len(group_lens)), sim_end=sim_end_, allow_none=allow_none, check_bounds=check_bounds, ) new_sim_end = np.empty(len(group_lens), dtype=int_) can_be_none = True group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] max_sim_end = 0 for col in range(from_col, to_col): _sim_end = flex_select_1d_pc_nb(sim_end_, col) if _sim_end < 0: _sim_end = target_shape[0] + _sim_end elif _sim_end > target_shape[0]: _sim_end = target_shape[0] if _sim_end > max_sim_end: max_sim_end = _sim_end new_sim_end[group] = max_sim_end if max_sim_end != target_shape[0]: can_be_none = False if allow_none and can_be_none: return None return new_sim_end @register_jitted(cache=True) def resolve_grouped_sim_range_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, allow_none: bool = False, check_bounds: bool = True, ) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]: """Resolve grouped simulation start and end.""" new_sim_start = resolve_grouped_sim_start_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start, allow_none=allow_none, check_bounds=check_bounds, ) new_sim_end = 
resolve_grouped_sim_end_nb( target_shape=target_shape, group_lens=group_lens, sim_end=sim_end, allow_none=allow_none, check_bounds=check_bounds, ) return new_sim_start, new_sim_end @register_jitted(cache=True) def resolve_ungrouped_sim_start_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, allow_none: bool = False, check_bounds: bool = True, ) -> tp.Optional[tp.Array1d]: """Resolve ungrouped simulation start.""" if sim_start is None: if allow_none: return None return np.full(target_shape[1], 0, dtype=int_) sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_)) if len(sim_start_) == target_shape[1]: if not check_bounds: return sim_start_ return resolve_sim_start_nb( target_shape, sim_start=sim_start_, allow_none=allow_none, check_bounds=check_bounds, ) new_sim_start = np.empty(target_shape[1], dtype=int_) can_be_none = True group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] _sim_start = flex_select_1d_pc_nb(sim_start_, group) if _sim_start < 0: _sim_start = target_shape[0] + _sim_start elif _sim_start > target_shape[0]: _sim_start = target_shape[0] for col in range(from_col, to_col): new_sim_start[col] = _sim_start if _sim_start != 0: can_be_none = False if allow_none and can_be_none: return None return new_sim_start @register_jitted(cache=True) def resolve_ungrouped_sim_end_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_end: tp.Optional[tp.FlexArray1dLike] = None, allow_none: bool = False, check_bounds: bool = True, ) -> tp.Optional[tp.Array1d]: """Resolve ungrouped simulation end.""" if sim_end is None: if allow_none: return None return np.full(target_shape[1], target_shape[0], dtype=int_) sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_)) if len(sim_end_) == target_shape[1]: if not check_bounds: return sim_end_ return resolve_sim_end_nb( 
target_shape, sim_end=sim_end_, allow_none=allow_none, check_bounds=check_bounds, ) new_sim_end = np.empty(target_shape[1], dtype=int_) can_be_none = True group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] _sim_end = flex_select_1d_pc_nb(sim_end_, group) if _sim_end < 0: _sim_end = target_shape[0] + _sim_end elif _sim_end > target_shape[0]: _sim_end = target_shape[0] for col in range(from_col, to_col): new_sim_end[col] = _sim_end if _sim_end != target_shape[0]: can_be_none = False if allow_none and can_be_none: return None return new_sim_end @register_jitted(cache=True) def resolve_ungrouped_sim_range_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, allow_none: bool = False, check_bounds: bool = True, ) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]: """Resolve ungrouped simulation start and end.""" new_sim_start = resolve_ungrouped_sim_start_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start, allow_none=allow_none, check_bounds=check_bounds, ) new_sim_end = resolve_ungrouped_sim_end_nb( target_shape=target_shape, group_lens=group_lens, sim_end=sim_end, allow_none=allow_none, check_bounds=check_bounds, ) return new_sim_start, new_sim_end @register_jitted(cache=True) def prepare_sim_start_nb( sim_shape: tp.Shape, sim_start: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare simulation start.""" if sim_start is None: return np.full(sim_shape[1], 0, dtype=int_) sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_)) if not check_bounds and len(sim_start_) == sim_shape[1]: return sim_start_ sim_start_out = np.empty(sim_shape[1], dtype=int_) for i in range(sim_shape[1]): _sim_start = flex_select_1d_pc_nb(sim_start_, i) if _sim_start < 0: 
_sim_start = sim_shape[0] + _sim_start elif _sim_start > sim_shape[0]: _sim_start = sim_shape[0] sim_start_out[i] = _sim_start return sim_start_out @register_jitted(cache=True) def prepare_sim_end_nb( sim_shape: tp.Shape, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare simulation end.""" if sim_end is None: return np.full(sim_shape[1], sim_shape[0], dtype=int_) sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_)) if not check_bounds and len(sim_end_) == sim_shape[1]: return sim_end_ new_sim_end = np.empty(sim_shape[1], dtype=int_) for i in range(sim_shape[1]): _sim_end = flex_select_1d_pc_nb(sim_end_, i) if _sim_end < 0: _sim_end = sim_shape[0] + _sim_end elif _sim_end > sim_shape[0]: _sim_end = sim_shape[0] new_sim_end[i] = _sim_end return new_sim_end @register_jitted(cache=True) def prepare_sim_range_nb( sim_shape: tp.Shape, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Tuple[tp.Array1d, tp.Array1d]: """Prepare simulation start and end.""" new_sim_start = prepare_sim_start_nb( sim_shape=sim_shape, sim_start=sim_start, check_bounds=check_bounds, ) new_sim_end = prepare_sim_end_nb( sim_shape=sim_shape, sim_end=sim_end, check_bounds=check_bounds, ) return new_sim_start, new_sim_end @register_jitted(cache=True) def prepare_grouped_sim_start_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare grouped simulation start.""" if sim_start is None: return np.full(len(group_lens), 0, dtype=int_) sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_)) if len(sim_start_) == len(group_lens): if not check_bounds: return sim_start_ return prepare_sim_start_nb( (target_shape[0], len(group_lens)), sim_start=sim_start_, check_bounds=check_bounds, ) new_sim_start = np.empty(len(group_lens), dtype=int_) group_end_idxs = 
np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] min_sim_start = target_shape[0] for col in range(from_col, to_col): _sim_start = flex_select_1d_pc_nb(sim_start_, col) if _sim_start < 0: _sim_start = target_shape[0] + _sim_start elif _sim_start > target_shape[0]: _sim_start = target_shape[0] if _sim_start < min_sim_start: min_sim_start = _sim_start new_sim_start[group] = min_sim_start return new_sim_start @register_jitted(cache=True) def prepare_grouped_sim_end_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare grouped simulation end.""" if sim_end is None: return np.full(len(group_lens), target_shape[0], dtype=int_) sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_)) if len(sim_end_) == len(group_lens): if not check_bounds: return sim_end_ return prepare_sim_end_nb( (target_shape[0], len(group_lens)), sim_end=sim_end_, check_bounds=check_bounds, ) new_sim_end = np.empty(len(group_lens), dtype=int_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] max_sim_end = 0 for col in range(from_col, to_col): _sim_end = flex_select_1d_pc_nb(sim_end_, col) if _sim_end < 0: _sim_end = target_shape[0] + _sim_end elif _sim_end > target_shape[0]: _sim_end = target_shape[0] if _sim_end > max_sim_end: max_sim_end = _sim_end new_sim_end[group] = max_sim_end return new_sim_end @register_jitted(cache=True) def prepare_grouped_sim_range_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Tuple[tp.Array1d, tp.Array1d]: """Prepare grouped simulation start and end.""" new_sim_start 
= prepare_grouped_sim_start_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start, check_bounds=check_bounds, ) new_sim_end = prepare_grouped_sim_end_nb( target_shape=target_shape, group_lens=group_lens, sim_end=sim_end, check_bounds=check_bounds, ) return new_sim_start, new_sim_end @register_jitted(cache=True) def prepare_ungrouped_sim_start_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare ungrouped simulation start.""" if sim_start is None: return np.full(target_shape[1], 0, dtype=int_) sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_)) if len(sim_start_) == target_shape[1]: if not check_bounds: return sim_start_ return prepare_sim_start_nb( target_shape, sim_start=sim_start_, check_bounds=check_bounds, ) new_sim_start = np.empty(target_shape[1], dtype=int_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] _sim_start = flex_select_1d_pc_nb(sim_start_, group) if _sim_start < 0: _sim_start = target_shape[0] + _sim_start elif _sim_start > target_shape[0]: _sim_start = target_shape[0] for col in range(from_col, to_col): new_sim_start[col] = _sim_start return new_sim_start @register_jitted(cache=True) def prepare_ungrouped_sim_end_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Array1d: """Prepare ungrouped simulation end.""" if sim_end is None: return np.full(target_shape[1], target_shape[0], dtype=int_) sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_)) if len(sim_end_) == target_shape[1]: if not check_bounds: return sim_end_ return prepare_sim_end_nb( target_shape, sim_end=sim_end_, check_bounds=check_bounds, ) new_sim_end = np.empty(target_shape[1], dtype=int_) group_end_idxs = 
np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] _sim_end = flex_select_1d_pc_nb(sim_end_, group) if _sim_end < 0: _sim_end = target_shape[0] + _sim_end elif _sim_end > target_shape[0]: _sim_end = target_shape[0] for col in range(from_col, to_col): new_sim_end[col] = _sim_end return new_sim_end @register_jitted(cache=True) def prepare_ungrouped_sim_range_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, check_bounds: bool = True, ) -> tp.Tuple[tp.Array1d, tp.Array1d]: """Prepare ungrouped simulation start and end.""" new_sim_start = prepare_ungrouped_sim_start_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start, check_bounds=check_bounds, ) new_sim_end = prepare_ungrouped_sim_end_nb( target_shape=target_shape, group_lens=group_lens, sim_end=sim_end, check_bounds=check_bounds, ) return new_sim_start, new_sim_end # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# =================================================================================== """Modules for splitting.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.generic.splitting.base import * from vectorbtpro.generic.splitting.decorators import * from vectorbtpro.generic.splitting.nb import * from vectorbtpro.generic.splitting.purged import * from vectorbtpro.generic.splitting.sklearn_ import * __import_if_installed__ = dict() __import_if_installed__["sklearn_"] = "sklearn" # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Base class for splitting.""" import inspect import math import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base.accessors import BaseIDXAccessor from vectorbtpro.base.grouping.base import Grouper from vectorbtpro.base.indexes import combine_indexes, stack_indexes from vectorbtpro.base.indexing import hslice, PandasIndexer, get_index_ranges from vectorbtpro.base.merging import row_stack_merge, column_stack_merge, is_merge_func_from_config from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.reshaping import to_dict from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.generic.splitting import nb from vectorbtpro.generic.splitting.purged import BasePurgedCV, PurgedWalkForwardCV, PurgedKFoldCV from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks, 
datetime_ as dt from vectorbtpro.utils.annotations import Annotatable, has_annotatables from vectorbtpro.utils.array_ import is_range from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING from vectorbtpro.utils.colors import adjust_opacity from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig from vectorbtpro.utils.decorators import hybrid_method from vectorbtpro.utils.eval_ import Evaluable from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute from vectorbtpro.utils.merging import parse_merge_func, MergeFunc from vectorbtpro.utils.parsing import ( get_func_arg_names, annotate_args, flatten_ann_args, unflatten_ann_args, ann_args_to_args, ) from vectorbtpro.utils.selection import PosSel, LabelSel from vectorbtpro.utils.template import CustomTemplate, Rep, RepFunc, substitute_templates from vectorbtpro.utils.warnings_ import warn if tp.TYPE_CHECKING: from sklearn.model_selection import BaseCrossValidator as BaseCrossValidatorT else: BaseCrossValidatorT = "BaseCrossValidator" __all__ = [ "FixRange", "RelRange", "Takeable", "Splitter", ] __pdoc__ = {} SplitterT = tp.TypeVar("SplitterT", bound="Splitter") @define class FixRange(DefineMixin): """Class that represents a fixed range.""" range_: tp.FixRangeLike = define.field() """Range.""" @define class RelRange(DefineMixin): """Class that represents a relative range.""" offset: tp.Union[int, float, tp.TimedeltaLike] = define.field(default=0) """Offset. Floating values between 0 and 1 are considered relative. Can be negative.""" offset_anchor: str = define.field(default="prev_end") """Offset anchor. Supported are * 'start': Start of the range * 'end': End of the range * 'prev_start': Start of the previous range * 'prev_end': End of the previous range """ offset_space: str = define.field(default="free") """Offset space. 
Supported are * 'all': All space * 'free': Remaining space after the offset anchor * 'prev': Length of the previous range Applied only when `RelRange.offset` is a relative number.""" length: tp.Union[int, float, tp.TimedeltaLike] = define.field(default=1.0) """Length. Floating values between 0 and 1 are considered relative. Can be negative.""" length_space: str = define.field(default="free") """Length space. Supported are * 'all': All space * 'free': Remaining space after the offset * 'free_or_prev': Remaining space after the offset or the start/end of the previous range, depending what comes first in the direction of `RelRange.length` Applied only when `RelRange.length` is a relative number.""" out_of_bounds: str = define.field(default="warn") """Check if start and stop are within bounds. Supported are * 'keep': Keep out-of-bounds values * 'ignore': Ignore if out-of-bounds * 'warn': Emit a warning if out-of-bounds * 'raise": Raise an error if out-of-bounds """ is_gap: bool = define.field(default=False) """Whether the range acts as a gap.""" def __attrs_post_init__(self): object.__setattr__(self, "offset_anchor", self.offset_anchor.lower()) if self.offset_anchor not in ("start", "end", "prev_start", "prev_end", "next_start", "next_end"): raise ValueError(f"Invalid offset_anchor: '{self.offset_anchor}'") object.__setattr__(self, "offset_space", self.offset_space.lower()) if self.offset_space not in ("all", "free", "prev"): raise ValueError(f"Invalid offset_space: '{self.offset_space}'") object.__setattr__(self, "length_space", self.length_space.lower()) if self.length_space not in ("all", "free", "free_or_prev"): raise ValueError(f"Invalid length_space: '{self.length_space}'") object.__setattr__(self, "out_of_bounds", self.out_of_bounds.lower()) if self.out_of_bounds not in ("keep", "ignore", "warn", "raise"): raise ValueError(f"Invalid out_of_bounds: '{self.out_of_bounds}'") def to_slice( self, total_len: int, prev_start: int = 0, prev_end: int = 0, index: 
tp.Optional[tp.IndexLike] = None, freq: tp.Optional[tp.FrequencyLike] = None, ) -> slice: """Convert the relative range into a slice.""" if index is not None: index = dt.prepare_dt_index(index) freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False) offset_anchor = self.offset_anchor offset = self.offset length = self.length if not checks.is_number(offset) or not checks.is_number(length): if not isinstance(index, pd.DatetimeIndex): raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}") if offset_anchor == "start": if checks.is_number(offset): offset_anchor = 0 else: offset_anchor = index[0] elif offset_anchor == "end": if checks.is_number(offset): offset_anchor = total_len else: if freq is None: raise ValueError("Must provide frequency") offset_anchor = index[-1] + freq elif offset_anchor == "prev_start": if checks.is_number(offset): offset_anchor = prev_start else: offset_anchor = index[prev_start] else: if checks.is_number(offset): offset_anchor = prev_end else: if prev_end < total_len: offset_anchor = index[prev_end] else: if freq is None: raise ValueError("Must provide frequency") offset_anchor = index[-1] + freq if checks.is_float(offset) and 0 <= abs(offset) <= 1: if self.offset_space == "all": offset = offset_anchor + int(offset * total_len) elif self.offset_space == "free": if offset < 0: offset = int((1 + offset) * offset_anchor) else: offset = offset_anchor + int(offset * (total_len - offset_anchor)) else: offset = offset_anchor + int(offset * (prev_end - prev_start)) else: if checks.is_float(offset): if not offset.is_integer(): raise TypeError(f"Floating number for offset ({offset}) must be between 0 and 1") offset = offset_anchor + int(offset) elif not checks.is_int(offset): offset = offset_anchor + dt.to_freq(offset) if index[0] <= offset <= index[-1]: offset = index.get_indexer([offset], method="ffill")[0] elif offset < index[0]: if freq is None: raise ValueError("Must provide frequency") offset = 
-int((index[0] - offset) / freq) else: if freq is None: raise ValueError("Must provide frequency") offset = total_len - 1 + int((offset - index[-1]) / freq) else: offset = offset_anchor + offset if checks.is_float(length) and 0 <= abs(length) <= 1: if self.length_space == "all": length = int(length * total_len) elif self.length_space == "free": if length < 0: length = int(length * offset) else: length = int(length * (total_len - offset)) else: if length < 0: if offset > prev_end: length = int(length * (offset - prev_end)) else: length = int(length * offset) else: if offset < prev_start: length = int(length * (prev_start - offset)) else: length = int(length * (total_len - offset)) else: if checks.is_float(length): if not length.is_integer(): raise TypeError(f"Floating number for length ({length}) must be between 0 and 1") length = int(length) elif not checks.is_int(length): length = dt.to_freq(length) start = offset if checks.is_int(length): stop = start + length else: if 0 <= start < total_len: stop = index[start] + length elif start < 0: if freq is None: raise ValueError("Must provide frequency") stop = index[0] + start * freq + length else: if freq is None: raise ValueError("Must provide frequency") stop = index[-1] + (start - total_len + 1) * freq + length if stop <= index[-1]: stop = index.get_indexer([stop], method="bfill")[0] else: if freq is None: raise ValueError("Must provide frequency") stop = total_len - 1 + int((stop - index[-1]) / freq) if checks.is_int(length): if length < 0: start, stop = stop, start else: if length < pd.Timedelta(0): start, stop = stop, start if start < 0: if self.out_of_bounds == "ignore": start = 0 elif self.out_of_bounds == "warn": warn(f"Range start ({start}) is out of bounds") start = 0 elif self.out_of_bounds == "raise": raise ValueError(f"Range start ({start}) is out of bounds") if stop > total_len: if self.out_of_bounds == "ignore": stop = total_len elif self.out_of_bounds == "warn": warn(f"Range stop ({stop}) is out of 
bounds") stop = total_len elif self.out_of_bounds == "raise": raise ValueError(f"Range stop ({stop}) is out of bounds") if stop - start <= 0: raise ValueError("Range length is negative or zero") return slice(start, stop) @define class Takeable(Evaluable, Annotatable, DefineMixin): """Class that represents an object from which a range can be taken.""" obj: tp.Any = define.required_field() """Takeable object.""" remap_to_obj: bool = define.optional_field() """Whether to remap `Splitter.index` to the index of `Takeable.obj`. Otherwise, will assume that the object has the same index.""" index: tp.Optional[tp.IndexLike] = define.optional_field() """Index of the object. If not present, will be accessed using `Splitter.get_obj_index`.""" freq: tp.Optional[tp.FrequencyLike] = define.optional_field() """Frequency of `Takeable.index`.""" point_wise: bool = define.optional_field() """Whether to select one range point at a time and return a tuple.""" eval_id: tp.Optional[tp.MaybeSequence[tp.Hashable]] = define.field(default=None) """One or more identifiers at which to evaluate this instance.""" class ZeroLengthError(ValueError): """Thrown whenever a range has a length of zero.""" pass class Splitter(Analyzable): """Base class for splitting.""" @classmethod def from_splits( cls: tp.Type[SplitterT], index: tp.IndexLike, splits: tp.Splits, squeeze: bool = False, fix_ranges: bool = True, wrap_with_fixrange: bool = False, split_range_kwargs: tp.KwargsLike = None, split_check_template: tp.Optional[tp.CustomTemplate] = None, template_context: tp.KwargsLike = None, split_labels: tp.Optional[tp.IndexLike] = None, set_labels: tp.Optional[tp.IndexLike] = None, wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> SplitterT: """Create a `Splitter` instance from an iterable of splits. Argument `splits` supports both absolute and relative ranges. To transform relative ranges into the absolute format, enable `fix_ranges`. 
Arguments `split_range_kwargs` are then passed to `Splitter.split_range`. Enable `wrap_with_fixrange` to wrap any fixed range with `FixRange`. If the range is an array, it will be wrapped regardless of this argument to avoid building a 3d array. Pass a template via `split_check_template` to discard splits that do not fulfill certain criteria. The current split will be available as `split`. Should return a boolean (`False` to discard). Labels for splits and sets can be provided via `split_labels` and `set_labels` respectively. Both arguments can be provided as templates. The split array will be available as `splits`.""" index = dt.prepare_dt_index(index) if split_range_kwargs is None: split_range_kwargs = {} new_splits = [] removed_indices = [] for i, split in enumerate(splits): already_fixed = False if checks.is_number(split) or checks.is_td_like(split): split = cls.split_range( slice(None), split, template_context=template_context, index=index, wrap_with_fixrange=False, **split_range_kwargs, ) already_fixed = True new_split = split ndim = 2 elif cls.is_range_relative(split): new_split = [split] ndim = 1 elif not checks.is_sequence(split): new_split = [split] ndim = 1 elif isinstance(split, np.ndarray): new_split = [split] ndim = 1 else: new_split = split ndim = 2 if fix_ranges and not already_fixed: new_split = cls.split_range( slice(None), new_split, template_context=template_context, index=index, wrap_with_fixrange=False, **split_range_kwargs, ) _new_split = [] for range_ in new_split: if checks.is_number(range_) or checks.is_td_like(range_): range_ = RelRange(length=range_) if not isinstance(range_, (FixRange, RelRange)): if wrap_with_fixrange or checks.is_sequence(range_): _new_split.append(FixRange(range_)) else: _new_split.append(range_) else: _new_split.append(range_) if split_check_template is not None: _template_context = merge_dicts(dict(index=index, i=i, split=_new_split), template_context) split_ok = substitute_templates(split_check_template, 
_template_context, eval_id="split_check_template") if not split_ok: removed_indices.append(i) continue new_splits.append(_new_split) if len(new_splits) == 0: raise ValueError("Must provide at least one range") new_splits_arr = np.asarray(new_splits, dtype=object) if squeeze and new_splits_arr.shape[1] == 1: ndim = 1 if split_labels is None: split_labels = pd.RangeIndex(stop=new_splits_arr.shape[0], name="split") else: if isinstance(split_labels, CustomTemplate): _template_context = merge_dicts(dict(index=index, splits_arr=new_splits_arr), template_context) split_labels = substitute_templates(split_labels, _template_context, eval_id=split_labels) if not isinstance(split_labels, pd.Index): split_labels = pd.Index(split_labels, name="split") else: if not isinstance(split_labels, pd.Index): split_labels = pd.Index(split_labels, name="split") if len(removed_indices) > 0: split_labels = split_labels.delete(removed_indices) if set_labels is None: set_labels = pd.Index(["set_%d" % i for i in range(new_splits_arr.shape[1])], name="set") else: if isinstance(split_labels, CustomTemplate): _template_context = merge_dicts(dict(index=index, splits_arr=new_splits_arr), template_context) set_labels = substitute_templates(set_labels, _template_context, eval_id=set_labels) if not isinstance(set_labels, pd.Index): set_labels = pd.Index(set_labels, name="set") if wrapper_kwargs is None: wrapper_kwargs = {} wrapper = ArrayWrapper(index=split_labels, columns=set_labels, ndim=ndim, **wrapper_kwargs) return cls(wrapper, index, new_splits_arr, **kwargs) @classmethod def from_single( cls: tp.Type[SplitterT], index: tp.IndexLike, split: tp.Optional[tp.SplitLike], split_range_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> SplitterT: """Create a `Splitter` instance from a single split.""" if split_range_kwargs is None: split_range_kwargs = {} new_split = cls.split_range( slice(None), split, template_context=template_context, index=index, 
**split_range_kwargs, ) splits = [new_split] return cls.from_splits( index, splits, split_range_kwargs=split_range_kwargs, template_context=template_context, **kwargs, ) @classmethod def from_rolling( cls: tp.Type[SplitterT], index: tp.IndexLike, length: tp.Union[int, float, tp.TimedeltaLike], offset: tp.Union[int, float, tp.TimedeltaLike] = 0, offset_anchor: str = "prev_end", offset_anchor_set: tp.Optional[int] = 0, offset_space: str = "prev", backwards: tp.Union[bool, str] = False, split: tp.Optional[tp.SplitLike] = None, split_range_kwargs: tp.KwargsLike = None, range_bounds_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs, ) -> SplitterT: """Create a `Splitter` instance from a rolling range of a fixed length. Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance. Args: index (index_like): Index. length (int, float, or timedelta_like): See `RelRange.length`. offset (int, float, or timedelta_like): See `RelRange.offset`. offset_anchor (str): See `RelRange.offset_anchor`. offset_anchor_set (int): Offset anchor set. Selects the set from the previous range to be used as an offset anchor. If None, the whole previous split is considered as a single range. By default, it's the first set. offset_space (str): See `RelRange.offset_space`. backwards (bool or str): Whether to roll backwards. If 'sorted', will roll backwards and sort the resulting splits by the start index. split (any): Ranges to split the range into. If None, will produce the entire range as a single range. Otherwise, will use `Splitter.split_range` to split the range into multiple ranges. split_range_kwargs (dict): Keyword arguments passed to `Splitter.split_range`. range_bounds_kwargs (dict): Keyword arguments passed to `Splitter.get_range_bounds`. template_context (dict): Context used to substitute templates in ranges. freq (any): Index frequency in case it cannot be parsed from `index`. 
If None, will be parsed using `vectorbtpro.base.accessors.BaseIDXAccessor.get_freq`. **kwargs: Keyword arguments passed to the constructor of `Splitter`. Usage: * Divide a range into a set of non-overlapping ranges: ```pycon >>> from vectorbtpro import * >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_rolling(index, 30) >>> splitter.plot().show() ``` ![](/assets/images/api/from_rolling_1.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_rolling_1.dark.svg#only-dark){: .iimg loading=lazy } * Divide a range into ranges, each split into 1/2: ```pycon >>> splitter = vbt.Splitter.from_rolling( ... index, ... 60, ... split=1/2, ... set_labels=["train", "test"] ... ) >>> splitter.plot().show() ``` ![](/assets/images/api/from_rolling_2.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_rolling_2.dark.svg#only-dark){: .iimg loading=lazy } * Make the ranges above non-overlapping by using the right bound of the last set as an offset anchor: ```pycon >>> splitter = vbt.Splitter.from_rolling( ... index, ... 60, ... offset_anchor_set=-1, ... split=1/2, ... set_labels=["train", "test"] ... 
) >>> splitter.plot().show() ``` ![](/assets/images/api/from_rolling_3.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_rolling_3.dark.svg#only-dark){: .iimg loading=lazy } """ index = dt.prepare_dt_index(index) freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False) if isinstance(backwards, str): if backwards.lower() == "sorted": sort_backwards = True else: raise ValueError(f"Invalid backwards: '{backwards}'") backwards = True else: sort_backwards = False if split_range_kwargs is None: split_range_kwargs = {} if "freq" not in split_range_kwargs: split_range_kwargs = dict(split_range_kwargs) split_range_kwargs["freq"] = freq if range_bounds_kwargs is None: range_bounds_kwargs = {} splits = [] bounds = [] while True: if len(splits) == 0: new_split = RelRange( length=-length if backwards else length, offset_anchor="end" if backwards else "start", out_of_bounds="keep", ).to_slice(total_len=len(index), index=index, freq=freq) else: if offset_anchor_set is None: prev_start, prev_end = bounds[-1][0][0], bounds[-1][-1][1] else: prev_start, prev_end = bounds[-1][offset_anchor_set] new_split = RelRange( offset=offset, offset_anchor=offset_anchor, offset_space=offset_space, length=-length if backwards else length, length_space="all", out_of_bounds="keep", ).to_slice(total_len=len(index), prev_start=prev_start, prev_end=prev_end, index=index, freq=freq) if backwards: if new_split.stop >= bounds[-1][-1][1]: raise ValueError("Infinite loop detected. Provide a positive offset.") else: if new_split.start <= bounds[-1][0][0]: raise ValueError("Infinite loop detected. 
@classmethod
def from_n_rolling(
    cls: tp.Type[SplitterT],
    index: tp.IndexLike,
    n: int,
    length: tp.Union[None, str, int, float, tp.TimedeltaLike] = None,
    optimize_anchor_set: int = 1,
    split: tp.Optional[tp.SplitLike] = None,
    split_range_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
    **kwargs,
) -> SplitterT:
    """Create a `Splitter` instance from a number of rolling ranges of the same length.

    If `length` is None, splits the index evenly into `n` non-overlapping ranges
    using `Splitter.from_rolling`. Otherwise, picks `n` evenly-spaced, potentially overlapping
    ranges of a fixed length. For other arguments, see `Splitter.from_rolling`.

    If `length` is "optimize", searches for a length to cover most of the index.
    Use `optimize_anchor_set` to provide the index of a set that should become non-overlapping.

    Args:
        index (index_like): Index.
        n (int): Number of ranges.

            When `length` is a fixed length, `n` is capped at the number of possible offsets.
        length (None, str, int, float, or timedelta_like): Length of each range.

            A float whose absolute value is within [0, 1] is multiplied by the index length;
            any other float must be a whole number. A non-numeric value is
            converted with `dt.to_freq` and requires a frequency.
        optimize_anchor_set (int): Either 0 or 1. Used only when `length` is "optimize".
        split (any): Ranges to split each range into. Must be a float (or None)
            when `length` is "optimize".
        split_range_kwargs (dict): Keyword arguments passed to `Splitter.split_range`.

            The resolved frequency is added under the key "freq" unless already present.
        template_context (dict): Context used to substitute templates in ranges.
        freq (any): Index frequency in case it cannot be parsed from `index`.

            Also forwarded to `Splitter.from_rolling` when this method delegates to it.
        **kwargs: Keyword arguments passed to `Splitter.from_rolling` (when delegating)
            or to `Splitter.from_splits`.

    Raises:
        TypeError: If `split` is not a float when `length` is "optimize", if a floating `length`
            is neither within [0, 1] nor a whole number, or if `length` is out of bounds.
        ValueError: If `length` is timedelta-like and no frequency is available.

    Usage:
        * Roll 10 ranges with 100 elements, and split it into 3/4:

        ```pycon
        >>> from vectorbtpro import *

        >>> index = pd.date_range("2020", "2021", freq="D")

        >>> splitter = vbt.Splitter.from_n_rolling(
        ...     index,
        ...     10,
        ...     length=100,
        ...     split=3/4,
        ...     set_labels=["train", "test"]
        ... )
        >>> splitter.plot().show()
        ```

        ![](/assets/images/api/from_n_rolling.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/from_n_rolling.dark.svg#only-dark){: .iimg loading=lazy }
    """
    index = dt.prepare_dt_index(index)
    freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
    if split_range_kwargs is None:
        split_range_kwargs = {}
    if "freq" not in split_range_kwargs:
        # Copy before writing so the caller's dictionary is left untouched
        split_range_kwargs = dict(split_range_kwargs)
        split_range_kwargs["freq"] = freq

    if length is None:
        # Even, non-overlapping partition: each range starts where the previous split ended.
        # The resolved frequency is forwarded so that a user-supplied `freq` is not dropped.
        return cls.from_rolling(
            index,
            length=len(index) // n,
            offset=0,
            offset_anchor="prev_end",
            offset_anchor_set=None,
            split=split,
            split_range_kwargs=split_range_kwargs,
            template_context=template_context,
            freq=freq,
            **kwargs,
        )

    if isinstance(length, str) and length.lower() == "optimize":
        from scipy.optimize import minimize_scalar

        if split is not None and not checks.is_float(split):
            raise TypeError("Split must be a float when length='optimize'")
        checks.assert_in(optimize_anchor_set, (0, 1))

        ratio = 1.0 if split is None else split

        def empty_len_objective(length):
            # Number of index elements left uncovered by `n` ranges of the candidate length;
            # candidates that overshoot the index get the worst possible score
            length = math.ceil(length)
            first_len = int(ratio * length)
            second_len = length - first_len
            if split is None or optimize_anchor_set == 0:
                empty_len = len(index) - (n * first_len + second_len)
            else:
                empty_len = len(index) - (n * second_len + first_len)
            if empty_len >= 0:
                return empty_len
            return len(index)

        length = math.ceil(minimize_scalar(empty_len_objective).x)
        if split is None or optimize_anchor_set == 0:
            offset = int(ratio * length)
        else:
            offset = length - int(ratio * length)
        # Forward the resolved frequency here as well (see the branch above)
        return cls.from_rolling(
            index,
            length=length,
            offset=offset,
            offset_anchor="prev_start",
            offset_anchor_set=None,
            split=split,
            split_range_kwargs=split_range_kwargs,
            template_context=template_context,
            freq=freq,
            **kwargs,
        )

    if checks.is_float(length):
        if 0 <= abs(length) <= 1:
            length = len(index) * length
        elif not length.is_integer():
            raise TypeError("Floating number for length must be between 0 and 1")
        length = int(length)
    if checks.is_int(length):
        if length < 1 or length > len(index):
            raise TypeError(f"Length must be within [{1}, {len(index)}]")
        # Every start position from which a full-length range still fits into the index
        offsets = np.arange(len(index))
        offsets = offsets[offsets + length <= len(index)]
    else:
        length = dt.to_freq(length)
        if freq is None:
            raise ValueError("Must provide freq")
        if length < freq or length > index[-1] + freq - index[0]:
            raise TypeError(f"Length must be within [{freq}, {index[-1] + freq - index[0]}]")
        # Offsets are expressed relative to the first index element
        offsets = index[index + length <= index[-1] + freq] - index[0]
    if n > len(offsets):
        n = len(offsets)
    # Pick `n` evenly spaced offsets, always including the first and the last one
    rows = np.round(np.linspace(0, len(offsets) - 1, n)).astype(int)
    offsets = offsets[rows]

    splits = []
    for offset in offsets:
        new_split = RelRange(
            offset=offset,
            length=length,
        ).to_slice(len(index), index=index, freq=freq)
        if split is not None:
            new_split = cls.split_range(
                new_split,
                split,
                template_context=template_context,
                index=index,
                **split_range_kwargs,
            )
        splits.append(new_split)

    return cls.from_splits(
        index,
        splits,
        split_range_kwargs=split_range_kwargs,
        template_context=template_context,
        **kwargs,
    )
Usage: * Roll an expanding range with a length of 10 and an offset of 10, and split it into 3/4: ```pycon >>> from vectorbtpro import * >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_expanding( ... index, ... 10, ... 10, ... split=3/4, ... set_labels=["train", "test"] ... ) >>> splitter.plot().show() ``` ![](/assets/images/api/from_expanding.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_expanding.dark.svg#only-dark){: .iimg loading=lazy } """ index = dt.prepare_dt_index(index) freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False) if split_range_kwargs is None: split_range_kwargs = {} if range_bounds_kwargs is None: range_bounds_kwargs = {} if "freq" not in split_range_kwargs: split_range_kwargs = dict(split_range_kwargs) split_range_kwargs["freq"] = freq splits = [] bounds = [] while True: if len(splits) == 0: new_split = RelRange( length=min_length, out_of_bounds="keep", ).to_slice(total_len=len(index), index=index, freq=freq) else: prev_end = bounds[-1][-1][-1] new_split = RelRange( offset=offset, offset_anchor="prev_end", offset_space="all", length=-1.0, out_of_bounds="keep", ).to_slice(total_len=len(index), prev_end=prev_end, index=index, freq=freq) if new_split.stop <= prev_end: raise ValueError("Infinite loop detected. 
@classmethod
def from_n_expanding(
    cls: tp.Type[SplitterT],
    index: tp.IndexLike,
    n: int,
    min_length: tp.Union[None, int, float, tp.TimedeltaLike] = None,
    split: tp.Optional[tp.SplitLike] = None,
    split_range_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
    **kwargs,
) -> SplitterT:
    """Create a `Splitter` instance from a number of expanding ranges.

    Picks `n` evenly-spaced, expanding ranges. Argument `min_length` defines the minimum
    length for each range. For other arguments, see `Splitter.from_rolling`.

    Args:
        index (index_like): Index.
        n (int): Number of ranges. Capped at the number of possible lengths.
        min_length (None, int, float, or timedelta_like): Minimum length of each range.

            If None, becomes `len(index) // n`. A float whose absolute value is within [0, 1]
            is multiplied by the index length; any other float must be a whole number.
            Floats are converted to integers. A non-numeric value is
            converted with `dt.to_freq` and requires a frequency.
        split (any): Ranges to split each range into.
        split_range_kwargs (dict): Keyword arguments passed to `Splitter.split_range`.

            The resolved frequency is added under the key "freq" unless already present.
        template_context (dict): Context used to substitute templates in ranges.
        freq (any): Index frequency in case it cannot be parsed from `index`.
        **kwargs: Keyword arguments passed to `Splitter.from_splits`.

    Raises:
        TypeError: If a floating `min_length` is neither within [0, 1] nor a whole number,
            or if `min_length` is out of bounds.
        ValueError: If `min_length` is timedelta-like and no frequency is available.

    Usage:
        * Roll 10 expanding ranges with a minimum length of 100, while reserving 50 elements for test:

        ```pycon
        >>> from vectorbtpro import *

        >>> index = pd.date_range("2020", "2021", freq="D")

        >>> splitter = vbt.Splitter.from_n_expanding(
        ...     index,
        ...     10,
        ...     min_length=100,
        ...     split=-50,
        ...     set_labels=["train", "test"]
        ... )
        >>> splitter.plot().show()
        ```

        ![](/assets/images/api/from_n_expanding.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/from_n_expanding.dark.svg#only-dark){: .iimg loading=lazy }
    """
    index = dt.prepare_dt_index(index)
    freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
    if split_range_kwargs is None:
        split_range_kwargs = {}
    if "freq" not in split_range_kwargs:
        # Copy before writing so the caller's dictionary is left untouched
        split_range_kwargs = dict(split_range_kwargs)
        split_range_kwargs["freq"] = freq

    if min_length is None:
        min_length = len(index) // n
    if checks.is_float(min_length):
        if 0 <= abs(min_length) <= 1:
            min_length = len(index) * min_length
        elif not min_length.is_integer():
            raise TypeError("Floating number for minimum length must be between 0 and 1")
        # Cast here so that a float takes the integer branch below; without the cast
        # it would be handed to `dt.to_freq` as if it were timedelta-like
        min_length = int(min_length)
    if checks.is_int(min_length):
        min_length = int(min_length)
        if min_length < 1 or min_length > len(index):
            raise TypeError(f"Minimum length must be within [{1}, {len(index)}]")
        # All admissible range lengths, measured in number of elements
        lengths = np.arange(1, len(index) + 1)
        lengths = lengths[lengths >= min_length]
    else:
        min_length = dt.to_freq(min_length)
        if freq is None:
            raise ValueError("Must provide freq")
        if min_length < freq or min_length > index[-1] + freq - index[0]:
            raise TypeError(f"Minimum length must be within [{freq}, {index[-1] + freq - index[0]}]")
        # All admissible range lengths, measured as durations from the first index element
        lengths = index[1:].append(index[[-1]] + freq) - index[0]
        lengths = lengths[lengths >= min_length]
    if n > len(lengths):
        n = len(lengths)
    # Pick `n` evenly spaced lengths, always including the shortest and the longest one
    rows = np.round(np.linspace(0, len(lengths) - 1, n)).astype(int)
    lengths = lengths[rows]

    splits = []
    for length in lengths:
        new_split = RelRange(length=length).to_slice(len(index), index=index, freq=freq)
        if split is not None:
            new_split = cls.split_range(
                new_split,
                split,
                template_context=template_context,
                index=index,
                **split_range_kwargs,
            )
        splits.append(new_split)

    return cls.from_splits(
        index,
        splits,
        split_range_kwargs=split_range_kwargs,
        template_context=template_context,
        **kwargs,
    )
None, split_range_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> SplitterT: """Create a `Splitter` instance from ranges. Uses `vectorbtpro.base.indexing.get_index_ranges` to generate start and end indices. Passes only related keyword arguments found in `kwargs`. Other keyword arguments will be passed to `Splitter.from_splits`. For details on `split` and `split_range_kwargs`, see `Splitter.from_rolling`. Usage: * Translate each quarter into a range: ```pycon >>> from vectorbtpro import * >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_ranges(index, every="QS") >>> splitter.plot().show() ``` ![](/assets/images/api/from_ranges_1.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_ranges_1.dark.svg#only-dark){: .iimg loading=lazy } * In addition to the above, reserve the last month for testing purposes: ```pycon >>> splitter = vbt.Splitter.from_ranges( ... index, ... every="QS", ... split=(1.0, lambda index: index.month == index.month[-1]), ... split_range_kwargs=dict(backwards=True) ... 
@classmethod
def from_grouper(
    cls: tp.Type[SplitterT],
    index: tp.IndexLike,
    by: tp.AnyGroupByLike,
    groupby_kwargs: tp.KwargsLike = None,
    grouper_kwargs: tp.KwargsLike = None,
    split: tp.Optional[tp.SplitLike] = None,
    split_range_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    split_labels: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
    **kwargs,
) -> SplitterT:
    """Create a `Splitter` instance where each group of a grouper becomes one split.

    The grouper is built with `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
    If `by` is a template, it is substituted first using a context that contains `index`
    merged with `template_context`.

    Each group is optionally divided with `Splitter.split_range`; groups for which
    this raises `ZeroLengthError` are dropped. If `split_labels` is None, the labels
    are taken from the grouper's index at the positions of the groups that were kept.

    Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance.

    Usage:
        * Map each month into a range:

        ```pycon
        >>> from vectorbtpro import *

        >>> index = pd.date_range("2020", "2021", freq="D")

        >>> def is_month_end(index, split):
        ...     last_range = split[-1]
        ...     return index[last_range][-1].is_month_end

        >>> splitter = vbt.Splitter.from_grouper(
        ...     index,
        ...     "M",
        ...     split_check_template=vbt.RepFunc(is_month_end)
        ... )
        >>> splitter.plot().show()
        ```

        ![](/assets/images/api/from_grouper.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/from_grouper.dark.svg#only-dark){: .iimg loading=lazy }
    """
    index = dt.prepare_dt_index(index)
    freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)

    range_kwargs = {} if split_range_kwargs is None else split_range_kwargs
    if "freq" not in range_kwargs:
        # Build a new dictionary so the caller's one is not modified
        range_kwargs = dict(range_kwargs, freq=freq)
    extra_grouper_kwargs = {} if grouper_kwargs is None else grouper_kwargs

    if isinstance(by, CustomTemplate):
        by_context = merge_dicts(dict(index=index), template_context)
        by = substitute_templates(by, by_context, eval_id="by")
    grouper = BaseIDXAccessor(index).get_grouper(
        by,
        groupby_kwargs=groupby_kwargs,
        **extra_grouper_kwargs,
    )

    kept_splits = []
    kept_positions = []
    for group_pos, group_idxs in enumerate(grouper.iter_group_idxs()):
        if split is None:
            # The whole group forms a single-set split
            group_split = [group_idxs]
        else:
            try:
                group_split = cls.split_range(
                    group_idxs,
                    split,
                    template_context=template_context,
                    index=index,
                    **range_kwargs,
                )
            except ZeroLengthError:
                # Skip this group and leave its label out as well
                continue
        kept_splits.append(group_split)
        kept_positions.append(group_pos)

    if split_labels is None:
        split_labels = grouper.get_index()[kept_positions]

    return cls.from_splits(
        index,
        kept_splits,
        split_range_kwargs=range_kwargs,
        template_context=template_context,
        split_labels=split_labels,
        **kwargs,
    )
template_context: tp.KwargsLike = None, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs, ) -> SplitterT: """Create a `Splitter` instance from a number of random ranges. Randomly picks the length of a range between `min_length` and `max_length` (including) using `length_choice_func`, which receives an array of possible values and selects one. It defaults to `numpy.random.Generator.choice`. Optional function `length_p_func` takes the same as `length_choice_func` and must return either None or probabilities. Randomly picks the start position of a range starting at `min_start` and ending at `max_end` (excluding) minus the chosen length using `start_choice_func`, which receives an array of possible values and selects one. It defaults to `numpy.random.Generator.choice`. Optional function `start_p_func` takes the same as `start_choice_func` and must return either None or probabilities. !!! note Each function must take two arguments: the iteration index and the array with possible values. For other arguments, see `Splitter.from_rolling`. Usage: * Generate 20 random ranges with a length from [40, 100], and split each into 3/4: ```pycon >>> from vectorbtpro import * >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_n_random( ... index, ... 20, ... min_length=40, ... max_length=100, ... split=3/4, ... set_labels=["train", "test"] ... 
) >>> splitter.plot().show() ``` ![](/assets/images/api/from_n_random.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_n_random.dark.svg#only-dark){: .iimg loading=lazy } """ index = dt.prepare_dt_index(index) freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False) if split_range_kwargs is None: split_range_kwargs = {} if "freq" not in split_range_kwargs: split_range_kwargs = dict(split_range_kwargs) split_range_kwargs["freq"] = freq if min_start is None: min_start = 0 if min_start is not None: if checks.is_float(min_start): if 0 <= abs(min_start) <= 1: min_start = len(index) * min_start elif not min_start.is_integer(): raise TypeError("Floating number for minimum start must be between 0 and 1") if checks.is_float(min_start): min_start = int(min_start) if checks.is_int(min_start): if min_start < 0 or min_start > len(index) - 1: raise TypeError(f"Minimum start must be within [{0}, {len(index) - 1}]") else: if not isinstance(index, pd.DatetimeIndex): raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}") min_start = dt.try_align_dt_to_index(min_start, index) if not isinstance(min_start, pd.Timestamp): raise ValueError(f"Minimum start ({min_start}) could not be parsed") if min_start < index[0] or min_start > index[-1]: raise TypeError(f"Minimum start must be within [{index[0]}, {index[-1]}]") min_start = index.get_indexer([min_start], method="bfill")[0] if max_end is None: max_end = len(index) if checks.is_float(max_end): if 0 <= abs(max_end) <= 1: max_end = len(index) * max_end elif not max_end.is_integer(): raise TypeError("Floating number for maximum end must be between 0 and 1") if checks.is_float(max_end): max_end = int(max_end) if checks.is_int(max_end): if max_end < 1 or max_end > len(index): raise TypeError(f"Maximum end must be within [{1}, {len(index)}]") else: if not isinstance(index, pd.DatetimeIndex): raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}") max_end 
= dt.try_align_dt_to_index(max_end, index) if not isinstance(max_end, pd.Timestamp): raise ValueError(f"Maximum end ({max_end}) could not be parsed") if freq is None: raise ValueError("Must provide freq") if max_end < index[0] + freq or max_end > index[-1] + freq: raise TypeError(f"Maximum end must be within [{index[0] + freq}, {index[-1] + freq}]") if max_end > index[-1]: max_end = len(index) else: max_end = index.get_indexer([max_end], method="bfill")[0] space_len = max_end - min_start if not checks.is_number(min_length): index_min_start = index[min_start] if max_end < len(index): index_max_end = index[max_end] else: if freq is None: raise ValueError("Must provide freq") index_max_end = index[-1] + freq index_space_len = index_max_end - index_min_start else: index_min_start = None index_max_end = None index_space_len = None if checks.is_float(min_length): if 0 <= abs(min_length) <= 1: min_length = space_len * min_length elif not min_length.is_integer(): raise TypeError("Floating number for minimum length must be between 0 and 1") min_length = int(min_length) if checks.is_int(min_length): if min_length < 1 or min_length > space_len: raise TypeError(f"Minimum length must be within [{1}, {space_len}]") else: min_length = dt.to_freq(min_length) if freq is None: raise ValueError("Must provide freq") if min_length < freq or min_length > index_space_len: raise TypeError(f"Minimum length must be within [{freq}, {index_space_len}]") if max_length is not None: if checks.is_float(max_length): if 0 <= abs(max_length) <= 1: max_length = space_len * max_length elif not max_length.is_integer(): raise TypeError("Floating number for maximum length must be between 0 and 1") max_length = int(max_length) if checks.is_int(max_length): if max_length < min_length or max_length > space_len: raise TypeError(f"Maximum length must be within [{min_length}, {space_len}]") else: max_length = dt.to_freq(max_length) if freq is None: raise ValueError("Must provide freq") if max_length < 
@classmethod
def from_sklearn(
    cls: tp.Type[SplitterT],
    index: tp.IndexLike,
    skl_splitter: BaseCrossValidatorT,
    groups: tp.Optional[tp.ArrayLike] = None,
    split_labels: tp.Optional[tp.IndexLike] = None,
    set_labels: tp.Optional[tp.IndexLike] = None,
    **kwargs,
) -> SplitterT:
    """Create a `Splitter` instance from a scikit-learn cross-validator.

    `skl_splitter` is checked to be an instance of `sklearn.model_selection.BaseCrossValidator`.
    Its `split` method is called on a single-column array of index positions together
    with `groups`, and everything it yields is passed as splits to `Splitter.from_splits`.

    If `set_labels` is None, the sets are labeled "train" and "test".
    Remaining keyword arguments are passed to `Splitter.from_splits`."""
    from sklearn.model_selection import BaseCrossValidator

    index = dt.prepare_dt_index(index)
    checks.assert_instance_of(skl_splitter, BaseCrossValidator)

    labels = ["train", "test"] if set_labels is None else set_labels
    # One row per index element, one column holding its position
    positions = np.arange(len(index))[:, np.newaxis]
    skl_splits = list(skl_splitter.split(positions, groups=groups))

    return cls.from_splits(
        index,
        skl_splits,
        split_labels=split_labels,
        set_labels=labels,
        **kwargs,
    )
@classmethod
def from_purged_kfold(
    cls: tp.Type[SplitterT],
    index: tp.IndexLike,
    n_folds: int = 10,
    n_test_folds: int = 2,
    purge_td: tp.TimedeltaLike = 0,
    embargo_td: tp.TimedeltaLike = 0,
    pred_times: tp.Union[None, tp.Index, tp.Series] = None,
    eval_times: tp.Union[None, tp.Index, tp.Series] = None,
    **kwargs,
) -> SplitterT:
    """Create a `Splitter` instance from `vectorbtpro.generic.splitting.purged.PurgedKFoldCV`.

    Arguments `n_folds`, `n_test_folds`, `purge_td`, and `embargo_td` are passed to the
    constructor of `PurgedKFoldCV`. The resulting cross-validator, `pred_times`, `eval_times`,
    and the remaining keyword arguments are passed to `Splitter.from_purged`."""
    index = dt.prepare_dt_index(index)

    cv_config = dict(
        n_folds=n_folds,
        n_test_folds=n_test_folds,
        purge_td=purge_td,
        embargo_td=embargo_td,
    )
    kfold_cv = PurgedKFoldCV(**cv_config)

    return cls.from_purged(
        index,
        kfold_cv,
        pred_times=pred_times,
        eval_times=eval_times,
        **kwargs,
    )
# Previous split, first set, right bound ... prev_end = bounds[-1][0][1] ... new_split = ( ... slice(prev_end, prev_end + 20), ... slice(prev_end + 20, prev_end + 30) ... ) ... if new_split[-1].stop > len(index): ... return None ... return new_split >>> splitter = vbt.Splitter.from_split_func( ... index, ... split_func, ... split_args=( ... vbt.Rep("splits"), ... vbt.Rep("bounds"), ... vbt.Rep("index"), ... ), ... set_labels=["train", "test"] ... ) >>> splitter.plot().show() ``` ![](/assets/images/api/from_split_func.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_split_func.dark.svg#only-dark){: .iimg loading=lazy } """ index = dt.prepare_dt_index(index) freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False) if split_range_kwargs is None: split_range_kwargs = {} if "freq" not in split_range_kwargs: split_range_kwargs = dict(split_range_kwargs) split_range_kwargs["freq"] = freq if range_bounds_kwargs is None: range_bounds_kwargs = {} if split_args is None: split_args = () if split_kwargs is None: split_kwargs = {} splits = [] bounds = [] split_idx = 0 n_sets = None while True: _template_context = merge_dicts( dict( split_idx=split_idx, splits=splits, bounds=bounds, prev_start=bounds[-1][0][0] if len(bounds) > 0 else None, prev_end=bounds[-1][-1][1] if len(bounds) > 0 else None, index=index, freq=freq, fix_ranges=fix_ranges, split_args=split_args, split_kwargs=split_kwargs, split_range_kwargs=split_range_kwargs, range_bounds_kwargs=range_bounds_kwargs, **kwargs, ), template_context, ) _split_func = substitute_templates(split_func, _template_context, eval_id="split_func") _split_args = substitute_templates(split_args, _template_context, eval_id="split_args") _split_kwargs = substitute_templates(split_kwargs, _template_context, eval_id="split_kwargs") new_split = _split_func(*_split_args, **_split_kwargs) if new_split is None: break if not checks.is_iterable(new_split): new_split = (new_split,) if fix_ranges or split is not 
@classmethod
def guess_method(cls, **kwargs) -> tp.Optional[str]:
    """Guess the name of the factory method that fits the given keyword arguments.

    A method defined directly on this class qualifies if its name starts with "from_",
    all of its required arguments are among the provided keys (with "index" always
    counted as provided), and every provided key is accepted either by the method or by
    `Splitter.from_splits` (apart from "splits"). For `Splitter.from_ranges`, the arguments
    of `get_index_ranges` are accepted too.

    If several methods qualify, "from_n_rolling" wins whenever it is among them; otherwise
    the method with the fewest required and optional arguments is returned, with the
    earliest-defined method winning ties.

    Returns None if there are no keyword arguments or no method qualifies."""
    if not kwargs:
        return None

    provided = set(kwargs) | {"index"}
    shared_arg_names = set(get_func_arg_names(cls.from_splits))
    shared_arg_names.remove("splits")

    # (method name, number of own arguments) for every qualifying method, in definition order
    candidates = []
    for attr_name in cls.__dict__:
        if not attr_name.startswith("from_"):
            continue
        method = getattr(cls, attr_name)
        if not inspect.ismethod(method):
            continue

        required = set(get_func_arg_names(method, req_only=True))
        if required and not required <= provided:
            # Some required argument has not been provided
            continue

        optional = set(get_func_arg_names(method, opt_only=True))
        accepted = shared_arg_names | required | optional
        if attr_name == "from_ranges":
            accepted |= set(get_func_arg_names(get_index_ranges))
        if accepted and not provided <= accepted:
            # Some provided argument is unknown to this method
            continue

        candidates.append((attr_name, len(required) + len(optional)))

    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0][0]
    if any(name == "from_n_rolling" for name, _ in candidates):
        return "from_n_rolling"
    # `min` returns the first of equally small items, matching a stable sort
    return min(candidates, key=lambda item: item[1])[0]
If variable keyword arguments are provided, they will be used as `take_kwargs` if a splitter instance has been built, otherwise, arguments will be distributed based on the signatures of the factory method and `Splitter.take`.""" if splitter_kwargs is None: splitter_kwargs = {} else: splitter_kwargs = dict(splitter_kwargs) if take_kwargs is None: take_kwargs = {} else: take_kwargs = dict(take_kwargs) if _splitter_kwargs is None: _splitter_kwargs = {} if _take_kwargs is None: _take_kwargs = {} if len(var_kwargs) > 0: var_splitter_kwargs = {} var_take_kwargs = {} if splitter is None or not isinstance(splitter, cls): take_arg_names = get_func_arg_names(cls.take) if splitter is not None: if isinstance(splitter, str): splitter_arg_names = get_func_arg_names(getattr(cls, splitter)) else: splitter_arg_names = get_func_arg_names(splitter) for k, v in var_kwargs.items(): assigned = False if k in splitter_arg_names: var_splitter_kwargs[k] = v assigned = True if k != "split" and k in take_arg_names: var_take_kwargs[k] = v assigned = True if not assigned: raise ValueError(f"Argument '{k}' couldn't be assigned") else: for k, v in var_kwargs.items(): if k == "freq": var_splitter_kwargs[k] = v var_take_kwargs[k] = v elif k == "split" or k not in take_arg_names: var_splitter_kwargs[k] = v else: var_take_kwargs[k] = v else: var_take_kwargs = var_kwargs splitter_kwargs = merge_dicts(var_splitter_kwargs, splitter_kwargs) take_kwargs = merge_dicts(var_take_kwargs, take_kwargs) if len(splitter_kwargs) > 0: if splitter is None: splitter = cls.guess_method(**splitter_kwargs) if splitter is None: raise ValueError("Splitter method couldn't be guessed") else: if splitter is None: raise ValueError("Must provide splitter or splitter method") if not isinstance(splitter, cls): if isinstance(splitter, str): splitter = getattr(cls, splitter) for k, v in _splitter_kwargs.items(): if k not in splitter_kwargs: splitter_kwargs[k] = v splitter = splitter(index, template_context=template_context, 
@classmethod
def split_and_apply(
    cls,
    index: tp.IndexLike,
    apply_func: tp.Callable,
    *apply_args,
    splitter: tp.Union[None, str, SplitterT, tp.Callable] = None,
    splitter_kwargs: tp.KwargsLike = None,
    apply_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    _splitter_kwargs: tp.KwargsLike = None,
    _apply_kwargs: tp.KwargsLike = None,
    **var_kwargs,
) -> tp.Any:
    """Build a splitter from an index and run a function on it.

    Argument `splitter` may be a ready `Splitter` instance, the name of a factory
    method (such as "from_n_rolling"), or the factory method itself. When it is None,
    the factory method is guessed from the collected splitter arguments using
    `Splitter.guess_method`.

    Keyword arguments `splitter_kwargs` go to the factory method, while `apply_kwargs`
    go to `Splitter.apply`. Both take precedence over `var_kwargs`.

    Variable keyword arguments are routed as follows:

    * Splitter instance: everything goes to `Splitter.apply`
    * Factory method (or its name): each argument goes to whichever signature
        accepts it (possibly both, but "split" never goes to `Splitter.apply`);
        an argument accepted by neither raises an error
    * No splitter: "freq" goes to both, "split" and anything unknown to
        `Splitter.apply` goes to the factory method, the rest goes to `Splitter.apply`

    Arguments `_splitter_kwargs` and `_apply_kwargs` act as lowest-priority defaults:
    their keys are used only if not already present."""
    splitter_kwargs = {} if splitter_kwargs is None else dict(splitter_kwargs)
    apply_kwargs = {} if apply_kwargs is None else dict(apply_kwargs)
    if _splitter_kwargs is None:
        _splitter_kwargs = {}
    if _apply_kwargs is None:
        _apply_kwargs = {}

    if len(var_kwargs) > 0:
        routed_to_splitter = {}
        routed_to_apply = {}
        if isinstance(splitter, cls):
            # Splitter is already built, nothing left to pass to a factory method
            routed_to_apply = var_kwargs
        else:
            apply_arg_names = get_func_arg_names(cls.apply)
            if splitter is None:
                for name, value in var_kwargs.items():
                    if name == "freq":
                        routed_to_splitter[name] = value
                        routed_to_apply[name] = value
                    elif name == "split" or name not in apply_arg_names:
                        routed_to_splitter[name] = value
                    else:
                        routed_to_apply[name] = value
            else:
                factory = getattr(cls, splitter) if isinstance(splitter, str) else splitter
                splitter_arg_names = get_func_arg_names(factory)
                for name, value in var_kwargs.items():
                    for_splitter = name in splitter_arg_names
                    for_apply = name != "split" and name in apply_arg_names
                    if not for_splitter and not for_apply:
                        raise ValueError(f"Argument '{name}' couldn't be assigned")
                    if for_splitter:
                        routed_to_splitter[name] = value
                    if for_apply:
                        routed_to_apply[name] = value
        splitter_kwargs = merge_dicts(routed_to_splitter, splitter_kwargs)
        apply_kwargs = merge_dicts(routed_to_apply, apply_kwargs)

    if splitter is None:
        if len(splitter_kwargs) == 0:
            raise ValueError("Must provide splitter or splitter method")
        splitter = cls.guess_method(**splitter_kwargs)
        if splitter is None:
            raise ValueError("Splitter method couldn't be guessed")

    if not isinstance(splitter, cls):
        factory = getattr(cls, splitter) if isinstance(splitter, str) else splitter
        for name, value in _splitter_kwargs.items():
            if name not in splitter_kwargs:
                splitter_kwargs[name] = value
        splitter = factory(index, template_context=template_context, **splitter_kwargs)

    for name, value in _apply_kwargs.items():
        if name not in apply_kwargs:
            apply_kwargs[name] = value
    return splitter.apply(apply_func, *apply_args, template_context=template_context, **apply_kwargs)
@hybrid_method
def row_stack(
    cls_or_self: tp.MaybeType[SplitterT],
    *objs: tp.MaybeTuple[SplitterT],
    wrapper_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> SplitterT:
    """Stack multiple `Splitter` instances along rows.

    Can be called on the class or on an instance; in the latter case, the instance
    becomes the first object to be stacked. A single positional argument is treated
    as a sequence of objects.

    Unless `wrapper` is provided, uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack`
    (with `wrapper_kwargs`) to stack the wrappers."""
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        cls = type(cls_or_self)
        objs = (cls_or_self, *objs)
    if len(objs) == 1:
        objs = objs[0]
    splitters = list(objs)
    if not all(checks.is_instance_of(splitter, Splitter) for splitter in splitters):
        raise TypeError("Each object to be merged must be an instance of Splitter")

    if "wrapper" not in kwargs:
        if wrapper_kwargs is None:
            wrapper_kwargs = {}
        wrappers = [splitter.wrapper for splitter in splitters]
        kwargs["wrapper"] = ArrayWrapper.row_stack(*wrappers, stack_columns=False, **wrapper_kwargs)

    kwargs = cls.resolve_row_stack_kwargs(*splitters, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*splitters, **kwargs)
    return cls(**kwargs)
def __init__(
    self,
    wrapper: ArrayWrapper,
    index: tp.Index,
    splits_arr: tp.SplitsArray,
    **kwargs,
) -> None:
    # A splitter maps rows to splits and columns to sets, so grouping is not supported
    if wrapper.grouper.is_grouped():
        raise ValueError("Splitter cannot be grouped")
    index = dt.prepare_dt_index(index)

    # The array of splits must have the same 2-dim shape as the wrapper
    shape_checks = (
        (0, "Number of splits must match wrapper index"),
        (1, "Number of sets must match wrapper columns"),
    )
    for axis, message in shape_checks:
        if splits_arr.shape[axis] != wrapper.shape_2d[axis]:
            raise ValueError(message)

    Analyzable.__init__(
        self,
        wrapper,
        index=index,
        splits_arr=splits_arr,
        **kwargs,
    )

    self._index = index
    self._splits_arr = splits_arr
def get_split_grouper(self, split_group_by: tp.AnyGroupByLike = None) -> tp.Optional[Grouper]:
    """Get split grouper.

    Returns None if `split_group_by` is None and the grouper itself if it is already
    a `Grouper` instance. Anything else is turned into a grouper over the split labels
    using `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`."""
    if split_group_by is None or isinstance(split_group_by, Grouper):
        return split_group_by
    labels_accessor = BaseIDXAccessor(self.split_labels)
    return labels_accessor.get_grouper(split_group_by, def_lvl_name="split_group")
def get_n_sets(self, set_group_by: tp.AnyGroupByLike = None) -> int:
    """Get number of sets while considering the grouper.

    Without `set_group_by`, this is `Splitter.n_sets`; otherwise, it is the number
    of groups in the grouper built by `Splitter.get_set_grouper`."""
    if set_group_by is None:
        return self.n_sets
    set_grouper = self.get_set_grouper(set_group_by=set_group_by)
    return set_grouper.get_group_count()
def to_grouped(
    self: SplitterT,
    split: tp.Optional[tp.Selection] = None,
    set_: tp.Optional[tp.Selection] = None,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    merge_split_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> SplitterT:
    """Merge all ranges within the same group and return a new `Splitter` instance.

    Arguments `split` and `set_` optionally narrow the result down to the selected
    split and set groups. Each cell of the new array is built with
    `Splitter.select_range`, which receives `merge_split_kwargs`.

    Key `wrap_with_fixrange` in `merge_split_kwargs` must be True or None (default).
    Other keyword arguments are passed to `Splitter.replace`."""
    merge_split_kwargs = {} if merge_split_kwargs is None else dict(merge_split_kwargs)
    wrap_with_fixrange = merge_split_kwargs.pop("wrap_with_fixrange", None)
    if wrap_with_fixrange is False:
        raise ValueError("Argument wrap_with_fixrange must be True or None")
    merge_split_kwargs["wrap_with_fixrange"] = wrap_with_fixrange

    # Resolve groupers and the labels of the new wrapper
    split_group_by = self.get_split_grouper(split_group_by=split_group_by)
    new_index = self.get_split_labels(split_group_by=split_group_by)
    set_group_by = self.get_set_grouper(set_group_by=set_group_by)
    new_columns = self.get_set_labels(set_group_by=set_group_by)
    split_group_indices, set_group_indices, _, _ = self.select_indices(
        split=split,
        set_=set_,
        split_group_by=split_group_by,
        set_group_by=set_group_by,
    )
    if split is not None:
        new_index = new_index[split_group_indices]
    if set_ is not None:
        new_columns = new_columns[set_group_indices]

    # One merged range per (split group, set group) pair
    merged_rows = [
        [
            self.select_range(
                split=PosSel(split_group_idx),
                set_=PosSel(set_group_idx),
                split_group_by=split_group_by,
                set_group_by=set_group_by,
                merge_split_kwargs=merge_split_kwargs,
            )
            for set_group_idx in set_group_indices
        ]
        for split_group_idx in split_group_indices
    ]
    new_splits_arr = np.asarray(merged_rows, dtype=object)

    # Grouping sets may collapse the columns into a single one
    if set_group_by is not None and set_group_by.is_grouped():
        new_ndim = 1 if new_splits_arr.shape[1] == 1 else 2
    else:
        new_ndim = self.wrapper.ndim
    new_wrapper = self.wrapper.replace(index=new_index, columns=new_columns, ndim=new_ndim)
    return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **kwargs)
Argument `range_format` accepts the following options: * 'any': Return any format * 'indices': Return indices * 'mask': Return mask of the same length as index * 'slice': Return slice * 'slice_or_indices': If slice fails, return indices * 'slice_or_mask': If slice fails, return mask * 'slice_or_any': If slice fails, return any format """ if index is None: if isinstance(cls_or_self, type): raise ValueError("Must provide index") index = cls_or_self.index else: index = dt.prepare_dt_index(index) if range_format.lower() not in ( "any", "indices", "mask", "slice", "slice_or_indices", "slice_or_mask", "slice_or_any", ): raise ValueError(f"Invalid range_format: '{range_format}'") meta = dict() meta["was_fixed"] = False meta["was_template"] = False meta["was_callable"] = False meta["was_relative"] = False meta["was_hslice"] = False meta["was_slice"] = False meta["was_neg_slice"] = False meta["was_datetime"] = False meta["was_mask"] = False meta["was_indices"] = False meta["is_constant"] = False meta["start"] = None meta["stop"] = None meta["length"] = None if isinstance(range_, FixRange): meta["was_fixed"] = True range_ = range_.range_ if isinstance(range_, CustomTemplate): meta["was_template"] = True if template_context is None: template_context = {} if "index" not in template_context: template_context["index"] = index range_ = range_.substitute(context=template_context, eval_id="range") if callable(range_): meta["was_callable"] = True range_ = range_(index) if cls_or_self.is_range_relative(range_): meta["was_relative"] = True if allow_relative: if return_meta: meta["range_"] = range_ return meta return range_ raise TypeError("Relative ranges must be converted to fixed") if isinstance(range_, hslice): meta["was_hslice"] = True range_ = range_.to_slice() if isinstance(range_, slice): meta["was_slice"] = True meta["is_constant"] = True start = range_.start stop = range_.stop if range_.step is not None and range_.step > 1: raise ValueError("Step must be either None or 1") if 
start is not None and checks.is_int(start) and start < 0: if stop is not None and checks.is_int(stop) and stop > 0: raise ValueError("Slices must be either strictly negative or positive") meta["was_neg_slice"] = True start = len(index) + start if stop is not None and checks.is_int(stop): stop = len(index) + stop if start is None: start = 0 if stop is None: stop = len(index) if not checks.is_int(start): if not isinstance(index, pd.DatetimeIndex): raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}") start = dt.try_align_dt_to_index(start, index) if not isinstance(start, pd.Timestamp): raise ValueError(f"Range start ({start}) could not be parsed") meta["was_datetime"] = True if not checks.is_int(stop): if not isinstance(index, pd.DatetimeIndex): raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}") stop = dt.try_align_dt_to_index(stop, index) if not isinstance(stop, pd.Timestamp): raise ValueError(f"Range start ({stop}) could not be parsed") meta["was_datetime"] = True if checks.is_int(start): if start < 0: start = 0 else: if start < index[0]: start = 0 else: start = index.get_indexer([start], method="bfill")[0] if start == -1: raise ValueError(f"Range start ({start}) is out of bounds") if checks.is_int(stop): if stop > len(index): stop = len(index) else: if stop > index[-1]: stop = len(index) else: stop = index.get_indexer([stop], method="bfill")[0] if stop == -1: raise ValueError(f"Range stop ({stop}) is out of bounds") range_ = slice(start, stop) meta["start"] = start meta["stop"] = stop meta["length"] = stop - start if not allow_zero_len and meta["length"] == 0: raise ZeroLengthError("Range has zero length") if range_format.lower() == "indices": range_ = np.arange(*range_.indices(len(index))) elif range_format.lower() == "mask": mask = np.full(len(index), False) mask[range_] = True range_ = mask else: range_ = np.asarray(range_) if np.issubdtype(range_.dtype, np.bool_): if len(range_) != len(index): raise 
ValueError("Mask must have the same length as index") meta["was_mask"] = True indices = np.flatnonzero(range_) if len(indices) == 0: if not allow_zero_len: raise ZeroLengthError("Range has zero length") meta["is_constant"] = True meta["start"] = 0 meta["stop"] = 0 meta["length"] = 0 else: meta["is_constant"] = is_range(indices) meta["start"] = indices[0] meta["stop"] = indices[-1] + 1 meta["length"] = len(indices) if range_format.lower() == "indices": range_ = indices elif range_format.lower().startswith("slice"): if not meta["is_constant"]: if range_format.lower() == "slice": raise ValueError("Cannot convert to slice: range is not constant") if range_format.lower() == "slice_or_indices": range_ = indices else: range_ = slice(meta["start"], meta["stop"]) else: if not np.issubdtype(range_.dtype, np.integer): range_ = dt.try_align_to_dt_index(range_, index) if not isinstance(range_, pd.DatetimeIndex): raise ValueError("Range array could not be parsed") range_ = index.get_indexer(range_, method=None) if -1 in range_: raise ValueError(f"Range array has values that cannot be found in index") if np.issubdtype(range_.dtype, np.integer): meta["was_indices"] = True if len(range_) == 0: if not allow_zero_len: raise ZeroLengthError("Range has zero length") meta["is_constant"] = True meta["start"] = 0 meta["stop"] = 0 meta["length"] = 0 else: meta["is_constant"] = is_range(range_) if meta["is_constant"]: meta["start"] = range_[0] meta["stop"] = range_[-1] + 1 else: meta["start"] = np.min(range_) meta["stop"] = np.max(range_) + 1 meta["length"] = len(range_) if range_format.lower() == "mask": mask = np.full(len(index), False) mask[range_] = True range_ = mask elif range_format.lower().startswith("slice"): if not meta["is_constant"]: if range_format.lower() == "slice": raise ValueError("Cannot convert to slice: range is not constant") if range_format.lower() == "slice_or_mask": mask = np.full(len(index), False) mask[range_] = True range_ = mask else: range_ = 
@hybrid_method
def split_range(
    cls_or_self,
    range_: tp.FixRangeLike,
    new_split: tp.SplitLike,
    backwards: bool = False,
    allow_zero_len: bool = False,
    range_format: tp.Optional[str] = None,
    wrap_with_template: bool = False,
    wrap_with_fixrange: tp.Optional[bool] = False,
    template_context: tp.KwargsLike = None,
    index: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.FixSplit:
    """Split a fixed range into a split of multiple fixed ranges.

    Range must be either a template, a callable, a tuple (start and stop), a slice,
    a sequence of indices, or a mask. This range will then be re-mapped into the index.

    Each sub-range in `new_split` can be either a fixed or relative range, that is, an instance
    of `RelRange` or a number that will be used as a length to create an `RelRange`.
    Each sub-range will then be re-mapped into the main range. Argument `new_split` can also
    be provided as an integer or a float indicating the length; in such a case the second part
    (or the first one depending on `backwards`) will stretch. A negative number (or timedelta)
    flips `backwards`. If `new_split` is a string, the following options are supported:

    * 'by_gap': Split `range_` by gap using `vectorbtpro.generic.splitting.nb.split_range_by_gap_nb`.

    New ranges are returned relative to the index and in the same order as passed.
    Relative ranges that are gaps are consumed but not returned.

    For `range_format`, see `Splitter.get_ready_range`. If None, the format is derived
    from the format of `range_`.

    Enable `wrap_with_template` to wrap the resulting ranges with a template of the type
    `vectorbtpro.utils.template.Rep`.

    Argument `wrap_with_fixrange` controls wrapping of the resulting ranges with `FixRange`:
    True wraps each range, False (default) wraps none, and None wraps only ranges
    that are sequences.

    Raises an error if called on the class without `index`."""
    if index is None:
        if isinstance(cls_or_self, type):
            raise ValueError("Must provide index")
        index = cls_or_self.index
    else:
        index = dt.prepare_dt_index(index)

    # Prepare source range
    range_meta = cls_or_self.get_ready_range(
        range_,
        allow_zero_len=allow_zero_len,
        range_format="slice_or_indices",
        template_context=template_context,
        index=index,
        return_meta=True,
    )
    range_ = range_meta["range_"]
    range_was_hslice = range_meta["was_hslice"]
    range_was_indices = range_meta["was_indices"]
    range_was_mask = range_meta["was_mask"]
    range_length = range_meta["length"]
    if range_format is None:
        # Return ranges in a format that resembles the format of the source range
        if range_was_indices:
            range_format = "slice_or_indices"
        elif range_was_mask:
            range_format = "slice_or_mask"
        else:
            range_format = "slice_or_any"

    # Substitute template
    if isinstance(new_split, CustomTemplate):
        _template_context = merge_dicts(dict(index=index[range_]), template_context)
        new_split = substitute_templates(new_split, _template_context, eval_id="new_split")

    # Split by gap
    if isinstance(new_split, str) and new_split.lower() == "by_gap":
        if isinstance(range_, np.ndarray) and np.issubdtype(range_.dtype, np.integer):
            range_arr = range_
        else:
            range_arr = np.arange(len(index))[range_]
        start_idxs, stop_idxs = nb.split_range_by_gap_nb(range_arr)
        new_split = [slice(start_idx, stop_idx) for start_idx, stop_idx in zip(start_idxs, stop_idxs)]

    # Prepare target ranges
    if checks.is_number(new_split):
        if new_split < 0:
            backwards = not backwards
            new_split = abs(new_split)
        if not backwards:
            new_split = (new_split, 1.0)
        else:
            new_split = (1.0, new_split)
    elif checks.is_td_like(new_split):
        new_split = dt.to_freq(new_split)
        if new_split < pd.Timedelta(0):
            backwards = not backwards
            new_split = abs(new_split)
        if not backwards:
            new_split = (new_split, 1.0)
        else:
            new_split = (1.0, new_split)
    elif not checks.is_iterable(new_split):
        new_split = (new_split,)

    # Perform split
    new_ranges = []
    if backwards:
        # Walk from the right end of the source range towards its start
        new_split = new_split[::-1]
        prev_start = range_length
        prev_end = range_length
    else:
        prev_start = 0
        prev_end = 0
    for new_range in new_split:
        # Resolve new range
        new_range_meta = cls_or_self.get_ready_range(
            new_range,
            allow_relative=True,
            allow_zero_len=allow_zero_len,
            range_format="slice_or_any",
            template_context=template_context,
            index=index[range_],
            return_meta=True,
        )
        new_range = new_range_meta["range_"]
        if checks.is_number(new_range) or checks.is_td_like(new_range):
            new_range = RelRange(length=new_range)
        if isinstance(new_range, RelRange):
            new_range_is_gap = new_range.is_gap
            # Relative ranges are resolved in a mirrored coordinate system when going backwards
            new_range = new_range.to_slice(
                range_length,
                prev_start=range_length - prev_end if backwards else prev_start,
                prev_end=range_length - prev_start if backwards else prev_end,
                index=index,
                freq=freq,
            )
            if backwards:
                new_range = slice(range_length - new_range.stop, range_length - new_range.start)
        else:
            new_range_is_gap = False

        # Update previous bounds
        if isinstance(new_range, slice):
            prev_start = new_range.start
            prev_end = new_range.stop
        else:
            prev_start = new_range_meta["start"]
            prev_end = new_range_meta["stop"]

        # Remap new range to index
        if new_range_is_gap:
            continue
        if isinstance(range_, slice) and isinstance(new_range, slice):
            new_range = slice(
                range_.start + new_range.start,
                range_.start + new_range.stop,
            )
        else:
            if isinstance(range_, slice):
                new_range = np.arange(range_.start, range_.stop)[new_range]
            else:
                new_range = range_[new_range]
        new_range = cls_or_self.get_ready_range(
            new_range,
            allow_zero_len=allow_zero_len,
            range_format=range_format,
            index=index,
        )
        if isinstance(new_range, slice) and range_was_hslice:
            new_range = hslice.from_slice(new_range)
        if wrap_with_template:
            new_range = Rep("range_", context=dict(range_=new_range))
        if wrap_with_fixrange is None:
            _wrap_with_fixrange = checks.is_sequence(new_range)
        else:
            # Honor an explicit True instead of silently treating it as False
            _wrap_with_fixrange = wrap_with_fixrange
        if _wrap_with_fixrange:
            new_range = FixRange(new_range)
        new_ranges.append(new_range)

    if backwards:
        # Restore the order in which the ranges were passed
        return tuple(new_ranges)[::-1]
    return tuple(new_ranges)
def select_indices(
    self,
    split: tp.Optional[tp.Selection] = None,
    set_: tp.Optional[tp.Selection] = None,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
    """Get indices corresponding to selected splits and sets.

    Arguments `split` and `set_` can be either integers and labels. Also, multiple
    values are accepted; in such a case, the corresponding ranges are merged.
    If split/set labels are of an integer data type, treats the provided values as labels
    rather than indices, unless the split/set index is not of an integer data type or
    the values are wrapped with `vectorbtpro.utils.selection.PosSel`.

    Wrapping the entire selection with `vectorbtpro.utils.selection.PosSel` or
    `vectorbtpro.utils.selection.LabelSel` applies to every value in it. Wrapping
    a single value applies to that value only.

    If `split_group_by` and/or `set_group_by` are provided, their groupers get
    created using `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper` and
    arguments `split` and `set_` become relative to the groups.

    If `split`/`set_` is not provided, selects all indices.

    Returns four arrays: split group indices, set group indices, split indices, and set indices.
    Without a grouper, group indices and indices are the same array.

    Raises an error if a label cannot be found."""

    def _resolve_selection(selection, grouper, labels, n_items, what):
        # Translate a selection along one axis into (group indices, item indices)
        default_kind = None
        if isinstance(selection, PosSel):
            selection = selection.value
            default_kind = "positions"
        elif isinstance(selection, LabelSel):
            selection = selection.value
            default_kind = "labels"
        if checks.is_hashable(selection):
            selection = [selection]

        if grouper is not None:
            groups, lookup_index = grouper.get_groups_and_index()
            # Start with nothing selected so that an empty selection yields no indices
            mask = np.full(len(groups), False)
            what = what + " group"
        else:
            groups = None
            lookup_index = labels
            mask = None

        positions = []
        for value in selection:
            # Each value starts from the selection-wide kind; a wrapper around
            # a single value must not affect the values that follow it
            kind = default_kind
            if isinstance(value, PosSel):
                value = value.value
                kind = "positions"
            elif isinstance(value, LabelSel):
                value = value.value
                kind = "labels"
            if kind == "positions" or (
                kind is None and checks.is_int(value) and not pd.api.types.is_integer_dtype(lookup_index)
            ):
                position = value
            else:
                position = lookup_index.get_indexer([value])[0]
                if position == -1:
                    raise ValueError(f"{what} '{value}' not found")
            if mask is not None:
                mask |= groups == position
            positions.append(position)

        positions = np.asarray(positions)
        if mask is None:
            return positions, positions
        return positions, np.arange(n_items)[mask]

    split_group_by = self.get_split_grouper(split_group_by=split_group_by)
    set_group_by = self.get_set_grouper(set_group_by=set_group_by)

    if split is None:
        split_group_indices = np.arange(self.get_n_splits(split_group_by=split_group_by))
        split_indices = np.arange(self.n_splits)
    else:
        split_group_indices, split_indices = _resolve_selection(
            split,
            split_group_by,
            self.split_labels,
            self.n_splits,
            "Split",
        )

    if set_ is None:
        set_group_indices = np.arange(self.get_n_sets(set_group_by=set_group_by))
        set_indices = np.arange(self.n_sets)
    else:
        set_group_indices, set_indices = _resolve_selection(
            set_,
            set_group_by,
            self.set_labels,
            self.n_sets,
            "Set",
        )

    return split_group_indices, set_group_indices, split_indices, set_indices
@hybrid_method
def remap_range(
    cls_or_self,
    range_: tp.FixRangeLike,
    target_index: tp.IndexLike,
    target_freq: tp.Optional[tp.FrequencyLike] = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    silence_warnings: bool = False,
    index: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.FixRangeLike:
    """Translate a range defined on the source index into a target index.

    The source index is `index` if provided, otherwise the index of the instance
    (calling on the class without `index` raises `ValueError`, as does a missing `target_index`).

    If both indexes are equal, the range is returned untouched. Otherwise, the range is
    converted into a mask with `Splitter.get_range_mask` and resampled with
    `vectorbtpro.base.resampling.base.Resampler.resample_source_mask`, in which case
    `freq` and `target_freq` are passed to the resampler as source and target frequency."""
    # Resolve the source index
    if index is not None:
        index = dt.prepare_dt_index(index)
    elif isinstance(cls_or_self, type):
        raise ValueError("Must provide index")
    else:
        index = cls_or_self.index

    # Resolve the target index
    if target_index is None:
        raise ValueError("Must provide target index")
    target_index = dt.prepare_dt_index(target_index)

    # Nothing to remap if both indexes are identical
    if index.equals(target_index):
        return range_

    source_mask = cls_or_self.get_range_mask(
        range_,
        template_context=template_context,
        index=index,
    )
    resampler = Resampler(
        source_index=index,
        target_index=target_index,
        source_freq=freq,
        target_freq=target_freq,
    )
    return resampler.resample_source_mask(
        source_mask,
        jitted=jitted,
        silence_warnings=silence_warnings,
    )
@hybrid_method
def get_ready_obj_range(
    cls_or_self,
    obj: tp.Any,
    range_: tp.FixRangeLike,
    remap_to_obj: bool = True,
    obj_index: tp.Optional[tp.IndexLike] = None,
    obj_freq: tp.Optional[tp.FrequencyLike] = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    silence_warnings: bool = False,
    index: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
    return_obj_meta: bool = False,
    **ready_range_kwargs,
) -> tp.Any:
    """Get a range that is ready to be mapped into an array-like object.

    Remapping takes place only if `remap_to_obj` is True and either the object is Pandas-like
    (index, Series, DataFrame, or `PandasIndexer`) or `obj_index` is provided. If `obj_index`
    is None in this case, it is looked up in the object using `Splitter.get_obj_index`.
    The range is then remapped to that index using `Splitter.remap_range`.
    Otherwise, the range is kept and the source index and frequency are used for the object.

    Finally, uses `Splitter.get_ready_range` (with `**ready_range_kwargs`) to convert the range
    into the one that can be used directly in indexing.

    If `return_obj_meta` is True, returns a tuple of a dict with the keys "index" and "freq"
    describing the object, and the result of `Splitter.get_ready_range`."""
    if index is not None:
        index = dt.prepare_dt_index(index)
    elif isinstance(cls_or_self, type):
        raise ValueError("Must provide index")
    else:
        index = cls_or_self.index

    pandas_like = isinstance(obj, (pd.Index, pd.Series, pd.DataFrame, PandasIndexer))
    should_remap = remap_to_obj and (pandas_like or obj_index is not None)
    if not should_remap:
        # The object is addressed by the source index directly
        obj_index = index
        obj_freq = freq
        target_range = range_
    else:
        if obj_index is None:
            obj_index = cls_or_self.get_obj_index(obj)
        target_range = cls_or_self.remap_range(
            range_,
            target_index=obj_index,
            target_freq=obj_freq,
            template_context=template_context,
            jitted=jitted,
            silence_warnings=silence_warnings,
            index=index,
            freq=freq,
        )

    ready_range_or_meta = cls_or_self.get_ready_range(
        target_range,
        template_context=template_context,
        index=obj_index,
        **ready_range_kwargs,
    )
    if not return_obj_meta:
        return ready_range_or_meta
    return dict(index=obj_index, freq=obj_freq), ready_range_or_meta
def take(
    self,
    obj: tp.Any,
    split: tp.Optional[tp.Selection] = None,
    set_: tp.Optional[tp.Selection] = None,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    squeeze_one_split: bool = True,
    squeeze_one_set: bool = True,
    into: tp.Optional[str] = None,
    remap_to_obj: bool = True,
    obj_index: tp.Optional[tp.IndexLike] = None,
    obj_freq: tp.Optional[tp.FrequencyLike] = None,
    range_format: str = "slice_or_any",
    point_wise: bool = False,
    attach_bounds: tp.Union[bool, str] = False,
    right_inclusive: bool = False,
    template_context: tp.KwargsLike = None,
    silence_warnings: bool = False,
    index_combine_kwargs: tp.KwargsLike = None,
    stack_axis: int = 1,
    stack_kwargs: tp.KwargsLike = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.Any:
    """Take all ranges from an array-like object and optionally column-stack them.

    Uses `Splitter.select_indices` to get the indices for selected splits and sets.
    Arguments `split_group_by` and `set_group_by` can be used to group splits and sets respectively.
    Ranges belonging to the same split and set group will be merged.

    For each index pair, resolves the source range using `Splitter.select_range` and
    `Splitter.get_ready_range`. Then, remaps this range into the object index using
    `Splitter.get_ready_obj_range` and takes the slice from the object using `Splitter.take_range`.
    If the object is a custom template, substitutes it instead of calling `Splitter.take_range`.
    Finally, uses `vectorbtpro.base.merging.column_stack_merge` (`stack_axis=1`) or
    `vectorbtpro.base.merging.row_stack_merge` (`stack_axis=0`) with `stack_kwargs` to merge
    the taken slices.

    If `attach_bounds` is enabled, measures the bounds of each range and makes it an additional
    level in the final index hierarchy. The argument supports the following options:

    * True, 'index', 'source', or 'source_index': Attach source (index) bounds
    * 'target' or 'target_index': Attach target (index) bounds
    * False: Do not attach

    Argument `into` supports the following options:

    * None: Series of range slices
    * 'stacked': Stack all slices into a single object
    * 'stacked_by_split': Stack set slices in each split and return a Series of objects
    * 'stacked_by_set': Stack split slices in each set and return a Series of objects
    * 'split_major_meta': Generator with ranges processed lazily in split-major order.
        Returns meta with indices and labels, and the generator.
    * 'set_major_meta': Generator with ranges processed lazily in set-major order.
        Returns meta with indices and labels, and the generator.

    Prepend any stacked option with "from_start_" (also "reset_") or "from_end_" to reset the index
    from start and from end respectively.

    !!! note
        The passed `stack_kwargs` and `index_combine_kwargs` are never modified;
        the method works on its own copy of `stack_kwargs`.

    Raises `ValueError` on an invalid `attach_bounds`, `stack_axis`, or `into`, and when
    an index reset is requested together with `stack_axis=0`.

    Usage:
        * Roll a window and stack it along columns by keeping the index:

        ```pycon
        >>> from vectorbtpro import *

        >>> data = vbt.YFData.pull(
        ...     "BTC-USD",
        ...     start="2020-01-01 UTC",
        ...     end="2021-01-01 UTC"
        ... )
        >>> splitter = vbt.Splitter.from_n_rolling(
        ...     data.wrapper.index,
        ...     3,
        ...     length=5
        ... )
        >>> splitter.take(data.close, into="stacked")
        split                                0            1             2
        Date
        2020-01-01 00:00:00+00:00  7200.174316          NaN           NaN
        2020-01-02 00:00:00+00:00  6985.470215          NaN           NaN
        2020-01-03 00:00:00+00:00  7344.884277          NaN           NaN
        2020-01-04 00:00:00+00:00  7410.656738          NaN           NaN
        2020-01-05 00:00:00+00:00  7411.317383          NaN           NaN
        2020-06-29 00:00:00+00:00          NaN  9190.854492           NaN
        2020-06-30 00:00:00+00:00          NaN  9137.993164           NaN
        2020-07-01 00:00:00+00:00          NaN  9228.325195           NaN
        2020-07-02 00:00:00+00:00          NaN  9123.410156           NaN
        2020-07-03 00:00:00+00:00          NaN  9087.303711           NaN
        2020-12-27 00:00:00+00:00          NaN          NaN  26272.294922
        2020-12-28 00:00:00+00:00          NaN          NaN  27084.808594
        2020-12-29 00:00:00+00:00          NaN          NaN  27362.437500
        2020-12-30 00:00:00+00:00          NaN          NaN  28840.953125
        2020-12-31 00:00:00+00:00          NaN          NaN  29001.720703
        ```

        * Discard the index and attach index bounds to the column hierarchy:

        ```pycon
        >>> splitter.take(
        ...     data.close,
        ...     into="reset_stacked",
        ...     attach_bounds="index"
        ... )
        split                         0                         1  \\
        start 2020-01-01 00:00:00+00:00 2020-06-29 00:00:00+00:00
        end   2020-01-06 00:00:00+00:00 2020-07-04 00:00:00+00:00
        0                   7200.174316               9190.854492
        1                   6985.470215               9137.993164
        2                   7344.884277               9228.325195
        3                   7410.656738               9123.410156
        4                   7411.317383               9087.303711

        split                         2
        start 2020-12-27 00:00:00+00:00
        end   2021-01-01 00:00:00+00:00
        0                  26272.294922
        1                  27084.808594
        2                  27362.437500
        3                  28840.953125
        4                  29001.720703
        ```
    """
    # Normalize `attach_bounds` into None, "source", or "target" plus the `index_bounds` flag
    if isinstance(attach_bounds, bool):
        attach_bounds = "source" if attach_bounds else None
    index_bounds = False
    if attach_bounds is not None:
        if attach_bounds.lower() == "index":
            attach_bounds = "source"
            index_bounds = True
        if attach_bounds.lower() in ("source_index", "target_index"):
            attach_bounds = attach_bounds.split("_")[0]
            index_bounds = True
        if attach_bounds.lower() not in ("source", "target"):
            raise ValueError(f"Invalid attach_bounds: '{attach_bounds}'")
        attach_bounds = attach_bounds.lower()
    if index_combine_kwargs is None:
        index_combine_kwargs = {}
    if stack_axis not in (0, 1):
        raise ValueError("Axis for stacking must be either 0 or 1")
    # Work on a copy: "reset_index" may be written below and must not leak into the caller's dict
    stack_kwargs = {} if stack_kwargs is None else dict(stack_kwargs)

    split_group_by = self.get_split_grouper(split_group_by=split_group_by)
    split_labels = self.get_split_labels(split_group_by=split_group_by)
    set_group_by = self.get_set_grouper(set_group_by=set_group_by)
    set_labels = self.get_set_labels(set_group_by=set_group_by)
    split_group_indices, set_group_indices, split_indices, set_indices = self.select_indices(
        split=split,
        set_=set_,
        split_group_by=split_group_by,
        set_group_by=set_group_by,
    )
    if split is not None:
        split_labels = split_labels[split_group_indices]
    if set_ is not None:
        set_labels = set_labels[set_group_indices]
    n_splits = len(split_group_indices)
    n_sets = len(set_group_indices)
    one_split = n_splits == 1 and squeeze_one_split
    one_set = n_sets == 1 and squeeze_one_set
    one_range = one_split and one_set

    def _get_bounds(range_meta, obj_meta, obj_range_meta):
        # Return (start, end) of a range, either as positions or mapped to the index
        if attach_bounds is None:
            return None, None
        if attach_bounds == "source":
            start, stop = range_meta["start"], range_meta["stop"]
            map_kwargs = dict(freq=freq)
        else:
            start, stop = obj_range_meta["start"], obj_range_meta["stop"]
            map_kwargs = dict(index=obj_meta["index"], freq=obj_meta["freq"])
        if index_bounds:
            return self.map_bounds_to_index(start, stop, right_inclusive=right_inclusive, **map_kwargs)
        if right_inclusive:
            return start, stop - 1
        return start, stop

    def _get_range_meta(i, j):
        # Resolve the range of the i-th selected split group and j-th selected set group and take it
        split_idx = split_group_indices[i]
        set_idx = set_group_indices[j]
        range_ = self.select_range(
            split=PosSel(split_idx),
            set_=PosSel(set_idx),
            split_group_by=split_group_by,
            set_group_by=set_group_by,
            merge_split_kwargs=dict(template_context=template_context),
        )
        range_meta = self.get_ready_range(
            range_,
            range_format=range_format,
            template_context=template_context,
            return_meta=True,
        )
        obj_meta, obj_range_meta = self.get_ready_obj_range(
            obj,
            range_meta["range_"],
            remap_to_obj=remap_to_obj,
            obj_index=obj_index,
            obj_freq=obj_freq,
            range_format=range_format,
            template_context=template_context,
            silence_warnings=silence_warnings,
            freq=freq,
            return_obj_meta=True,
            return_meta=True,
        )
        if isinstance(obj, CustomTemplate):
            _template_context = merge_dicts(
                dict(
                    split_idx=split_idx,
                    set_idx=set_idx,
                    range_=obj_range_meta["range_"],
                    range_meta=obj_range_meta,
                    point_wise=point_wise,
                ),
                template_context,
            )
            obj_slice = substitute_templates(obj, _template_context, eval_id="take_range")
        else:
            obj_slice = self.take_range(obj, obj_range_meta["range_"], point_wise=point_wise)
        bounds = _get_bounds(range_meta, obj_meta, obj_range_meta)
        return dict(
            split_idx=split_idx,
            set_idx=set_idx,
            range_meta=range_meta,
            obj_range_meta=obj_range_meta,
            obj_slice=obj_slice,
            bounds=bounds,
        )

    def _attach_bounds(keys, range_bounds):
        # Append the bounds as "start"/"end" levels to the keys
        range_bounds = pd.MultiIndex.from_tuples(range_bounds, names=["start", "end"])
        if keys is None:
            return range_bounds
        clean_index_kwargs = dict(index_combine_kwargs)
        clean_index_kwargs.pop("ignore_ranges", None)
        return stack_indexes((keys, range_bounds), **clean_index_kwargs)

    def _take_flat():
        # Take every range in split-major order
        range_objs = []
        range_bounds = []
        for i in range(n_splits):
            for j in range(n_sets):
                range_meta = _get_range_meta(i, j)
                range_objs.append(range_meta["obj_slice"])
                range_bounds.append(range_meta["bounds"])
        return range_objs, range_bounds

    def _get_flat_keys(range_bounds):
        # Build the keys that label ranges taken in split-major order
        if one_set:
            keys = split_labels
        elif one_split:
            keys = set_labels
        else:
            keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
        if attach_bounds is not None:
            keys = _attach_bounds(keys, range_bounds)
        return keys

    def _stack(range_objs, keys):
        # Merge slices along `stack_axis`; user-provided `stack_kwargs` override the keys
        _stack_kwargs = merge_dicts(dict(keys=keys), stack_kwargs)
        if stack_axis == 0:
            return row_stack_merge(range_objs, **_stack_kwargs)
        return column_stack_merge(range_objs, **_stack_kwargs)

    def _take_grouped(n_outer, n_inner, get_meta, outer_labels, inner_labels, one_outer, one_inner):
        # Stack the inner ranges of each outer element and return one object per outer element
        outer_objs = []
        squeezed_bounds = []
        for a in range(n_outer):
            inner_objs = []
            inner_bounds = []
            for b in range(n_inner):
                range_meta = get_meta(a, b)
                inner_objs.append(range_meta["obj_slice"])
                inner_bounds.append(range_meta["bounds"])
            if one_inner:
                outer_objs.append(inner_objs[0])
                squeezed_bounds.append(inner_bounds[0])
            else:
                keys = inner_labels
                if attach_bounds is not None:
                    keys = _attach_bounds(keys, inner_bounds)
                outer_objs.append(_stack(inner_objs, keys))
        if one_outer:
            return outer_objs[0]
        if one_inner and attach_bounds is not None:
            outer_labels = _attach_bounds(outer_labels, squeezed_bounds)
        return pd.Series(outer_objs, index=outer_labels, dtype=object)

    if into is None:
        range_objs, range_bounds = _take_flat()
        if one_range:
            return range_objs[0]
        return pd.Series(range_objs, index=_get_flat_keys(range_bounds), dtype=object)

    # Strip index-reset prefixes; checked in sequence, so the last matching prefix wins
    reset_prefixes = (
        ("reset_", "from_start"),
        ("from_start_", "from_start"),
        ("from_end_", "from_end"),
    )
    for prefix, reset_index in reset_prefixes:
        if isinstance(into, str) and into.lower().startswith(prefix):
            if stack_axis == 0:
                raise ValueError("Cannot use reset_index with stack_axis=0")
            stack_kwargs["reset_index"] = reset_index
            into = into.lower().replace(prefix, "")

    into_lower = into.lower() if isinstance(into, str) else None

    if into_lower in ("split_major_meta", "set_major_meta"):
        meta = {
            "split_group_indices": split_group_indices,
            "set_group_indices": set_group_indices,
            "split_indices": split_indices,
            "set_indices": set_indices,
            "n_splits": n_splits,
            "n_sets": n_sets,
            "split_labels": split_labels,
            "set_labels": set_labels,
        }
        if into_lower == "split_major_meta":

            def _get_generator():
                for i in range(n_splits):
                    for j in range(n_sets):
                        yield _get_range_meta(i, j)

        else:

            def _get_generator():
                for j in range(n_sets):
                    for i in range(n_splits):
                        yield _get_range_meta(i, j)

        return meta, _get_generator()

    if into_lower == "stacked":
        range_objs, range_bounds = _take_flat()
        if one_range:
            return range_objs[0]
        return _stack(range_objs, _get_flat_keys(range_bounds))

    if into_lower == "stacked_by_split":
        return _take_grouped(
            n_splits,
            n_sets,
            _get_range_meta,
            split_labels,
            set_labels,
            one_split,
            one_set,
        )

    if into_lower == "stacked_by_set":
        return _take_grouped(
            n_sets,
            n_splits,
            lambda j, i: _get_range_meta(i, j),
            set_labels,
            split_labels,
            one_set,
            one_split,
        )

    raise ValueError(f"Invalid into: '{into}'")
def parse_and_inject_takeables(
    cls,
    flat_ann_args: tp.FlatAnnArgs,
    eval_id: tp.Optional[tp.Hashable] = None,
) -> tp.FlatAnnArgs:
    """Turn `Takeable` annotations into `Takeable` values in flattened annotated arguments.

    Returns a new dict in which each argument's dict is a shallow copy. An annotation that is
    a subclass of `Takeable` is instantiated first. If the annotation is a `Takeable`
    instance meeting `eval_id` and the argument has a value, this value either gets wrapped
    using the annotation (`Takeable.replace` with `obj`) or, if it's already a `Takeable`,
    gets merged over the annotation."""
    injected = dict()
    for arg_name, arg_info in flat_ann_args.items():
        arg_info = dict(arg_info)
        injected[arg_name] = arg_info
        if "annotation" not in arg_info:
            continue

        # Bare `Takeable` classes are instantiated with their defaults
        annotation = arg_info["annotation"]
        if isinstance(annotation, type) and issubclass(annotation, Takeable):
            annotation = annotation()
            arg_info["annotation"] = annotation

        if not isinstance(annotation, Takeable):
            continue
        if not annotation.meets_eval_id(eval_id):
            continue
        if "value" not in arg_info:
            continue

        value = arg_info["value"]
        if isinstance(value, Takeable):
            arg_info["value"] = value.merge_over(annotation)
        else:
            arg_info["value"] = annotation.replace(obj=value)
    return injected
Arguments `split_group_by` and `set_group_by` can be used to group splits and sets respectively. Ranges belonging to the same split and set group will be merged. For each index pair, in a lazily manner, resolves the source range using `Splitter.select_range` and `Splitter.get_ready_range`. Then, takes each argument from `args` and `kwargs` wrapped with `Takeable`, remaps the range into each object's index using `Splitter.get_ready_obj_range`, and takes the slice from that object using `Splitter.take_range`. The original object will be substituted by this slice. At the end, substitutes any templates in the prepared `args` and `kwargs` and saves the function and arguments for execution. For substitution, the following information is available: * `split/set_group_indices`: Indices corresponding to the selected row/column groups * `split/set_indices`: Indices corresponding to the selected rows/columns * `n_splits/sets`: Number of the selected rows/columns * `split/set_labels`: Labels corresponding to the selected row/column groups * `split/set_idx`: Index of the selected row/column * `split/set_label`: Label of the selected row/column * `range_`: Selected range ready for indexing (see `Splitter.get_ready_range`) * `range_meta`: Various information on the selected range * `obj_range_meta`: Various information on the range taken from each takeable argument. Positional arguments are denoted by position, keyword arguments are denoted by keys. * `args`: Positional arguments with ranges already selected * `kwargs`: Keyword arguments with ranges already selected * `bounds`: A tuple of either integer or index bounds. Can be source or target depending on `attach_bounds`. 
* `template_context`: Passed template context Since each range is processed lazily (that is, upon request), there are multiple iteration modes controlled by the argument `iteration`: * 'split_major': Flatten all ranges in split-major order and iterate over them * 'set_major': Flatten all ranges in set-major order and iterate over them * 'split_wise': Iterate over splits, while ranges in each split are processed sequentially * 'set_wise': Iterate over sets, while ranges in each set are processed sequentially The execution is done using `vectorbtpro.utils.execution.execute` with `execute_kwargs`. Once all results have been obtained, attempts to merge them using `merge_func` with `merge_kwargs` (all templates in it will be substituted as well), which can also be a string or a tuple of strings resolved using `vectorbtpro.base.merging.resolve_merge_func`. If `wrap_results` is enabled, packs the results into a Pandas object. If `apply_func` returns something complex, the resulting Pandas object will be of object data type. If `apply_func` returns a tuple (detected by the first returned result), a Pandas object is built for each element of that tuple. If `merge_all` is True, will merge all results in a flattened manner irrespective of the iteration mode. Otherwise, will merge by split/set. If `vectorbtpro.utils.execution.NoResult` is returned, will skip the current iteration and remove it from the final index. Usage: * Get the return of each data range: ```pycon >>> from vectorbtpro import * >>> data = vbt.YFData.pull( ... "BTC-USD", ... start="2020-01-01 UTC", ... end="2021-01-01 UTC" ... ) >>> splitter = vbt.Splitter.from_n_rolling(data.wrapper.index, 5) >>> def apply_func(data): ... return data.close.iloc[-1] - data.close.iloc[0] >>> splitter.apply(apply_func, vbt.Takeable(data)) split 0 -1636.467285 1 3706.568359 2 2944.720703 3 -118.113281 4 17098.916016 dtype: float64 ``` * The same but by indexing manually: ```pycon >>> def apply_func(range_, data): ... 
data = data.iloc[range_] ... return data.close.iloc[-1] - data.close.iloc[0] >>> splitter.apply(apply_func, vbt.Rep("range_"), data) split 0 -1636.467285 1 3706.568359 2 2944.720703 3 -118.113281 4 17098.916016 dtype: float64 ``` * Divide into two windows, each consisting of 50% train and 50% test, compute SMA for each range, and row-stack the outputs of each set upon merging: ```pycon >>> splitter = vbt.Splitter.from_n_rolling(data.wrapper.index, 2, split=0.5) >>> def apply_func(data): ... return data.run("SMA", 10).real >>> splitter.apply( ... apply_func, ... vbt.Takeable(data), ... merge_func="row_stack" ... ).unstack("set").vbt.drop_levels("split", axis=0).vbt.plot().show() ``` ![](/assets/images/api/Splitter_apply.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/Splitter_apply.dark.svg#only-dark){: .iimg loading=lazy } """ if isinstance(attach_bounds, bool): if attach_bounds: attach_bounds = "source" else: attach_bounds = None index_bounds = False if attach_bounds is not None: if attach_bounds.lower() == "index": attach_bounds = "source" index_bounds = True if attach_bounds.lower() in ("source_index", "target_index"): attach_bounds = attach_bounds.split("_")[0] index_bounds = True if attach_bounds.lower() not in ("source", "target"): raise ValueError(f"Invalid attach_bounds: '{attach_bounds}'") if index_combine_kwargs is None: index_combine_kwargs = {} if execute_kwargs is None: execute_kwargs = {} parsed_merge_func = parse_merge_func(apply_func, eval_id=eval_id) if parsed_merge_func is not None: if merge_func is not None: raise ValueError( f"Two conflicting merge functions: {parsed_merge_func} (annotations) and {merge_func} (merge_func)" ) merge_func = parsed_merge_func if merge_kwargs is None: merge_kwargs = {} split_group_by = self.get_split_grouper(split_group_by=split_group_by) split_labels = self.get_split_labels(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = 
self.get_set_labels(set_group_by=set_group_by) split_group_indices, set_group_indices, split_indices, set_indices = self.select_indices( split=split, set_=set_, split_group_by=split_group_by, set_group_by=set_group_by, ) if split is not None: split_labels = split_labels[split_group_indices] if set_ is not None: set_labels = set_labels[set_group_indices] n_splits = len(split_group_indices) n_sets = len(set_group_indices) one_split = n_splits == 1 and squeeze_one_split one_set = n_sets == 1 and squeeze_one_set one_range = one_split and one_set template_context = merge_dicts( { "splitter": self, "index": self.index, "split_group_indices": split_group_indices, "set_group_indices": set_group_indices, "split_indices": split_indices, "set_indices": set_indices, "n_splits": n_splits, "n_sets": n_sets, "split_labels": split_labels, "set_labels": set_labels, "one_split": one_split, "one_set": one_set, "one_range": one_range, }, template_context, ) template_context["eval_id"] = eval_id if has_annotatables(apply_func): ann_args = annotate_args( apply_func, apply_args, apply_kwargs, attach_annotations=True, ) flat_ann_args = flatten_ann_args(ann_args) flat_ann_args = self.parse_and_inject_takeables(flat_ann_args, eval_id=eval_id) ann_args = unflatten_ann_args(flat_ann_args) apply_args, apply_kwargs = ann_args_to_args(ann_args) def _get_range_meta(i, j, _template_context): split_idx = split_group_indices[i] set_idx = set_group_indices[j] range_ = self.select_range( split=PosSel(split_idx), set_=PosSel(set_idx), split_group_by=split_group_by, set_group_by=set_group_by, merge_split_kwargs=dict(template_context=_template_context), ) range_meta = self.get_ready_range( range_, range_format=range_format, template_context=_template_context, return_meta=True, ) return range_meta def _take_args(args, range_, _template_context): obj_meta = {} obj_range_meta = {} new_args = () if args is not None: for i, v in enumerate(args): if isinstance(v, Takeable) and v.meets_eval_id(eval_id): 
_obj_meta, _obj_range_meta, obj_slice = self.take_range_from_takeable( v, range_, remap_to_obj=remap_to_obj, obj_index=obj_index, obj_freq=obj_freq, range_format=range_format, point_wise=point_wise, template_context=_template_context, silence_warnings=silence_warnings, freq=freq, return_obj_meta=True, return_meta=True, ) new_args += (obj_slice,) obj_meta[i] = _obj_meta obj_range_meta[i] = _obj_range_meta else: new_args += (v,) return obj_meta, obj_range_meta, new_args def _take_kwargs(kwargs, range_, _template_context): obj_meta = {} obj_range_meta = {} new_kwargs = {} if kwargs is not None: for k, v in kwargs.items(): if isinstance(v, Takeable) and v.meets_eval_id(eval_id): _obj_meta, _obj_range_meta, obj_slice = self.take_range_from_takeable( v, range_, remap_to_obj=remap_to_obj, obj_index=obj_index, obj_freq=obj_freq, range_format=range_format, point_wise=point_wise, template_context=_template_context, silence_warnings=silence_warnings, freq=freq, return_obj_meta=True, return_meta=True, ) new_kwargs[k] = obj_slice obj_meta[k] = _obj_meta obj_range_meta[k] = _obj_range_meta else: new_kwargs[k] = v return obj_meta, obj_range_meta, new_kwargs def _get_bounds(range_meta, _template_context): if attach_bounds is not None: if isinstance(attach_bounds, str) and attach_bounds.lower() == "source": if index_bounds: bounds = self.map_bounds_to_index( range_meta["start"], range_meta["stop"], right_inclusive=right_inclusive, freq=freq, ) else: if right_inclusive: bounds = (range_meta["start"], range_meta["stop"] - 1) else: bounds = (range_meta["start"], range_meta["stop"]) else: obj_meta, obj_range_meta = self.get_ready_obj_range( self.index, range_meta["range_"], remap_to_obj=remap_to_obj, obj_index=obj_index, obj_freq=obj_freq, range_format=range_format, template_context=_template_context, silence_warnings=silence_warnings, freq=freq, return_obj_meta=True, return_meta=True, ) if index_bounds: bounds = self.map_bounds_to_index( obj_range_meta["start"], 
obj_range_meta["stop"], right_inclusive=right_inclusive, index=obj_meta["index"], freq=obj_meta["freq"], ) else: if right_inclusive: bounds = ( obj_range_meta["start"], obj_range_meta["stop"] - 1, ) else: bounds = ( obj_range_meta["start"], obj_range_meta["stop"], ) else: bounds = (None, None) return bounds bounds = {} def _get_task(i, j, _bounds=bounds): split_idx = split_group_indices[i] set_idx = set_group_indices[j] _template_context = merge_dicts( { "split_idx": split_idx, "split_label": split_labels[i], "set_idx": set_idx, "set_label": set_labels[j], }, template_context, ) range_meta = _get_range_meta(i, j, _template_context) _template_context = merge_dicts( dict(range_=range_meta["range_"], range_meta=range_meta), _template_context, ) obj_meta1, obj_range_meta1, _apply_args = _take_args(apply_args, range_meta["range_"], _template_context) obj_meta2, obj_range_meta2, _apply_kwargs = _take_kwargs( apply_kwargs, range_meta["range_"], _template_context ) obj_meta = {**obj_meta1, **obj_meta2} obj_range_meta = {**obj_range_meta1, **obj_range_meta2} _bounds[(i, j)] = _get_bounds(range_meta, _template_context) _template_context = merge_dicts( dict( obj_meta=obj_meta, obj_range_meta=obj_range_meta, apply_args=_apply_args, apply_kwargs=_apply_kwargs, bounds=_bounds[(i, j)], ), _template_context, ) _apply_func = substitute_templates(apply_func, _template_context, eval_id="apply_func") _apply_args = substitute_templates(_apply_args, _template_context, eval_id="apply_args") _apply_kwargs = substitute_templates(_apply_kwargs, _template_context, eval_id="apply_kwargs") return Task(_apply_func, *_apply_args, **_apply_kwargs) def _attach_bounds(keys, range_bounds): range_bounds = pd.MultiIndex.from_tuples(range_bounds, names=["start", "end"]) if keys is None: return range_bounds clean_index_kwargs = dict(index_combine_kwargs) clean_index_kwargs.pop("ignore_ranges", None) return stack_indexes((keys, range_bounds), **clean_index_kwargs) if iteration.lower() == "split_major": 
def _get_task_generator(): for i in range(n_splits): for j in range(n_sets): yield _get_task(i, j) tasks = _get_task_generator() keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs) if eval_id is not None: new_keys = [] for key in keys: if isinstance(keys, pd.MultiIndex): new_keys.append((MISSING, *key)) else: new_keys.append((MISSING, key)) keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names)) execute_kwargs = merge_dicts(dict(show_progress=False if one_split and one_set else None), execute_kwargs) results = execute(tasks, size=n_splits * n_sets, keys=keys, **execute_kwargs) elif iteration.lower() == "set_major": def _get_task_generator(): for j in range(n_sets): for i in range(n_splits): yield _get_task(i, j) tasks = _get_task_generator() keys = combine_indexes((set_labels, split_labels), **index_combine_kwargs) if eval_id is not None: new_keys = [] for key in keys: if isinstance(keys, pd.MultiIndex): new_keys.append((MISSING, *key)) else: new_keys.append((MISSING, key)) keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names)) execute_kwargs = merge_dicts(dict(show_progress=False if one_split and one_set else None), execute_kwargs) results = execute(tasks, size=n_splits * n_sets, keys=keys, **execute_kwargs) elif iteration.lower() == "split_wise": def _process_chunk_tasks(chunk_tasks): results = [] for func, args, kwargs in chunk_tasks: results.append(func(*args, **kwargs)) return results def _get_task_generator(): for i in range(n_splits): chunk_tasks = [] for j in range(n_sets): chunk_tasks.append(_get_task(i, j)) yield Task(_process_chunk_tasks, chunk_tasks) tasks = _get_task_generator() keys = split_labels if eval_id is not None: new_keys = [] for key in keys: if isinstance(keys, pd.MultiIndex): new_keys.append((MISSING, *key)) else: new_keys.append((MISSING, key)) keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names)) execute_kwargs = 
merge_dicts(dict(show_progress=False if one_split else None), execute_kwargs) results = execute(tasks, size=n_splits, keys=keys, **execute_kwargs) elif iteration.lower() == "set_wise": def _process_chunk_tasks(chunk_tasks): results = [] for func, args, kwargs in chunk_tasks: results.append(func(*args, **kwargs)) return results def _get_task_generator(): for j in range(n_sets): chunk_tasks = [] for i in range(n_splits): chunk_tasks.append(_get_task(i, j)) yield Task(_process_chunk_tasks, chunk_tasks) tasks = _get_task_generator() keys = set_labels if eval_id is not None: new_keys = [] for key in keys: if isinstance(keys, pd.MultiIndex): new_keys.append((MISSING, *key)) else: new_keys.append((MISSING, key)) keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names)) execute_kwargs = merge_dicts(dict(show_progress=False if one_set else None), execute_kwargs) results = execute(tasks, size=n_sets, keys=keys, **execute_kwargs) else: raise ValueError(f"Invalid iteration: '{iteration}'") if merge_all: if iteration.lower() in ("split_wise", "set_wise"): results = [result for _results in results for result in _results] if one_range: if results[0] is NoResult: if raise_no_results: raise NoResultsException return NoResult return results[0] if iteration.lower() in ("split_major", "split_wise"): if one_set: keys = split_labels elif one_split: keys = set_labels else: keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs) if attach_bounds is not None: range_bounds = [] for i in range(n_splits): for j in range(n_sets): range_bounds.append(bounds[(i, j)]) keys = _attach_bounds(keys, range_bounds) else: if one_set: keys = split_labels elif one_split: keys = set_labels else: keys = combine_indexes((set_labels, split_labels), **index_combine_kwargs) if attach_bounds is not None: range_bounds = [] for j in range(n_sets): for i in range(n_splits): range_bounds.append(bounds[(i, j)]) keys = _attach_bounds(keys, range_bounds) if filter_results: 
try: results, keys = filter_out_no_results(results, keys=keys) except NoResultsException as e: if raise_no_results: raise e return NoResult no_results_filtered = True else: no_results_filtered = False def _wrap_output(_results): try: return pd.Series(_results, index=keys) except Exception as e: return pd.Series(_results, index=keys, dtype=object) if merge_func is not None: template_context["tasks"] = tasks template_context["keys"] = keys if is_merge_func_from_config(merge_func): merge_kwargs = merge_dicts( dict( keys=keys, filter_results=not no_results_filtered, raise_no_results=raise_no_results, ), merge_kwargs, ) if isinstance(merge_func, MergeFunc): merge_func = merge_func.replace( merge_kwargs=merge_kwargs, context=template_context, ) else: merge_func = MergeFunc( merge_func, merge_kwargs=merge_kwargs, context=template_context, ) return merge_func(results) if wrap_results: if isinstance(results[0], tuple): return tuple(map(_wrap_output, zip(*results))) return _wrap_output(results) return results if iteration.lower() == "split_major": new_results = [] for i in range(n_splits): new_results.append(results[i * n_sets : (i + 1) * n_sets]) results = new_results elif iteration.lower() == "set_major": new_results = [] for i in range(n_sets): new_results.append(results[i * n_splits : (i + 1) * n_splits]) results = new_results if one_range: if results[0][0] is NoResult: if raise_no_results: raise NoResultsException return NoResult return results[0][0] split_bounds = [] if attach_bounds is not None: for i in range(n_splits): split_bounds.append([]) for j in range(n_sets): split_bounds[-1].append(bounds[(i, j)]) set_bounds = [] if attach_bounds is not None: for j in range(n_sets): set_bounds.append([]) for i in range(n_splits): set_bounds[-1].append(bounds[(i, j)]) if iteration.lower() in ("split_major", "split_wise"): major_keys = split_labels minor_keys = set_labels major_bounds = split_bounds minor_bounds = set_bounds one_major = one_split one_minor = one_set else: 
major_keys = set_labels minor_keys = split_labels major_bounds = set_bounds minor_bounds = split_bounds one_major = one_set one_minor = one_split if merge_func is not None: merged_results = [] keep_major_indices = [] for i, _results in enumerate(results): if one_minor: if _results[0] is not NoResult: merged_results.append(_results[0]) keep_major_indices.append(i) else: _template_context = dict(template_context) _template_context["tasks"] = tasks if attach_bounds is not None: minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[i]) else: minor_keys_wbounds = minor_keys if filter_results: _results, minor_keys_wbounds = filter_out_no_results( _results, keys=minor_keys_wbounds, raise_error=False, ) no_results_filtered = True else: no_results_filtered = False if len(_results) > 0: _template_context["keys"] = minor_keys_wbounds if is_merge_func_from_config(merge_func): _merge_kwargs = merge_dicts( dict( keys=minor_keys_wbounds, filter_results=not no_results_filtered, raise_no_results=False, ), merge_kwargs, ) else: _merge_kwargs = merge_kwargs if isinstance(merge_func, MergeFunc): _merge_func = merge_func.replace( merge_kwargs=_merge_kwargs, context=_template_context, ) else: _merge_func = MergeFunc( merge_func, merge_kwargs=_merge_kwargs, context=_template_context, ) _result = _merge_func(_results) if _result is not NoResult: merged_results.append(_result) keep_major_indices.append(i) if len(merged_results) == 0: if raise_no_results: raise NoResultsException return NoResult if len(merged_results) < len(major_keys): major_keys = major_keys[keep_major_indices] if one_major: return merged_results[0] if wrap_results: def _wrap_output(_results): try: return pd.Series(_results, index=major_keys) except Exception as e: return pd.Series(_results, index=major_keys, dtype=object) if isinstance(merged_results[0], tuple): return tuple(map(_wrap_output, zip(*merged_results))) return _wrap_output(merged_results) return merged_results if one_major: results = results[0] elif 
one_minor: results = [_results[0] for _results in results] if wrap_results: def _wrap_output(_results): if one_minor: if attach_bounds is not None: major_keys_wbounds = _attach_bounds(major_keys, minor_bounds[0]) else: major_keys_wbounds = major_keys if filter_results: try: _results, major_keys_wbounds = filter_out_no_results(_results, keys=major_keys_wbounds) except NoResultsException as e: if raise_no_results: raise e return NoResult try: return pd.Series(_results, index=major_keys_wbounds) except Exception as e: return pd.Series(_results, index=major_keys_wbounds, dtype=object) if one_major: if attach_bounds is not None: minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[0]) else: minor_keys_wbounds = minor_keys if filter_results: try: _results, major_keys_wbounds = filter_out_no_results(_results, keys=minor_keys_wbounds) except NoResultsException as e: if raise_no_results: raise e return NoResult try: return pd.Series(_results, index=minor_keys_wbounds) except Exception as e: return pd.Series(_results, index=minor_keys_wbounds, dtype=object) new_results = [] keep_major_indices = [] for i, r in enumerate(_results): if attach_bounds is not None: minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[i]) else: minor_keys_wbounds = minor_keys if filter_results: r, minor_keys_wbounds = filter_out_no_results( r, keys=minor_keys_wbounds, raise_error=False, ) if len(r) > 0: try: new_r = pd.Series(r, index=minor_keys_wbounds) except Exception as e: new_r = pd.Series(r, index=minor_keys_wbounds, dtype=object) new_results.append(new_r) keep_major_indices.append(i) if len(new_results) == 0: if raise_no_results: raise NoResultsException return NoResult if len(new_results) < len(major_keys): _major_keys = major_keys[keep_major_indices] else: _major_keys = major_keys try: return pd.Series(new_results, index=_major_keys) except Exception as e: return pd.Series(new_results, index=_major_keys, dtype=object) if one_major or one_minor: n_results = 1 for r in 
def break_up_splits(
    self: SplitterT,
    new_split: tp.SplitLike,
    sort: bool = False,
    template_context: tp.KwargsLike = None,
    wrapper_kwargs: tp.KwargsLike = None,
    init_kwargs: tp.KwargsLike = None,
    **split_range_kwargs,
) -> SplitterT:
    """Split each split into multiple splits.

    If there are multiple sets, make sure to merge them into one beforehand.

    Arguments `new_split` and `**split_range_kwargs` are passed to `Splitter.split_range`.

    Each new split is labeled with the label of the split it originates from plus its position
    among the new ranges; the position is stored in an additional index level "split_part".

    If `sort` is True, new splits are ordered by the first bound returned by
    `Splitter.get_range_bounds`. Splits with equal first bounds keep their original order.

    The new index is passed to the wrapper unless `index` is already defined in `wrapper_kwargs`.
    The dictionary passed as `wrapper_kwargs` is never modified.

    Raises `ValueError` if there is more than one set or if `wrap_with_fixrange` is False."""
    if self.n_sets > 1:
        raise ValueError("Cannot break up splits with more than one set. Merge sets first.")
    # Work on a copy: the index injected below must not leak into the caller's dict,
    # otherwise reusing that dict in another call would silently apply a stale index
    wrapper_kwargs = {} if wrapper_kwargs is None else dict(wrapper_kwargs)
    if init_kwargs is None:
        init_kwargs = {}
    split_range_kwargs = dict(split_range_kwargs)
    wrap_with_fixrange = split_range_kwargs.pop("wrap_with_fixrange", None)
    if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
        raise ValueError("Argument wrap_with_fixrange must be True or None")
    split_range_kwargs["wrap_with_fixrange"] = wrap_with_fixrange

    # Loop invariants
    split_labels = self.split_labels
    multi_level = isinstance(split_labels, pd.MultiIndex)

    new_splits_arr = []
    new_index = []
    range_starts = []
    for i, split in enumerate(self.splits_arr):
        new_ranges = self.split_range(
            split[0],
            new_split,
            template_context=template_context,
            **split_range_kwargs,
        )
        old_label = split_labels[i]
        for j, range_ in enumerate(new_ranges):
            if sort:
                range_starts.append(self.get_range_bounds(range_, template_context=template_context)[0])
            new_splits_arr.append([range_])
            if multi_level:
                new_index.append((*old_label, j))
            else:
                new_index.append((old_label, j))
    new_splits_arr = np.asarray(new_splits_arr, dtype=object)
    new_index = pd.MultiIndex.from_tuples(new_index, names=[*split_labels.names, "split_part"])
    if sort:
        # A stable sort keeps the original order of splits that start at the same position
        sorted_indices = np.argsort(range_starts, kind="stable")
        new_splits_arr = new_splits_arr[sorted_indices]
        new_index = new_index[sorted_indices]
    if "index" not in wrapper_kwargs:
        wrapper_kwargs["index"] = new_index
    new_wrapper = self.wrapper.replace(**wrapper_kwargs)
    return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
None, init_kwargs: tp.KwargsLike = None, **split_range_kwargs, ) -> SplitterT: """Split a set (column) into multiple sets (columns). Arguments `new_split` and `**split_range_kwargs` are passed to `Splitter.split_range`. Column must be provided if there are two or more sets. Use `new_set_labels` to specify the labels of the new sets; it must have the same length as there are new ranges in the new split. To provide final labels, define `columns` in `wrapper_kwargs`.""" if self.n_sets == 0: raise ValueError("There are no sets to split") if self.n_sets > 1: if column is None: raise ValueError("Must provide column for multiple sets") if not isinstance(column, int): column = self.set_labels.get_indexer([column])[0] if column == -1: raise ValueError(f"Column '{column}' not found") else: column = 0 if wrapper_kwargs is None: wrapper_kwargs = {} if init_kwargs is None: init_kwargs = {} split_range_kwargs = dict(split_range_kwargs) wrap_with_fixrange = split_range_kwargs.pop("wrap_with_fixrange", None) if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange: raise ValueError("Argument wrap_with_fixrange must be True or None") split_range_kwargs["wrap_with_fixrange"] = wrap_with_fixrange new_splits_arr = [] for split in self.splits_arr: new_ranges = self.split_range(split[column], new_split, **split_range_kwargs) new_splits_arr.append([*split[:column], *new_ranges, *split[column + 1 :]]) new_splits_arr = np.asarray(new_splits_arr, dtype=object) if "columns" not in wrapper_kwargs: wrapper_kwargs = dict(wrapper_kwargs) n_new_sets = new_splits_arr.shape[1] - self.n_sets + 1 if new_set_labels is None: old_set_label = self.set_labels[column] if isinstance(old_set_label, str) and len(old_set_label.split("+")) == n_new_sets: new_set_labels = old_set_label.split("+") else: new_set_labels = [str(old_set_label) + "/" + str(i) for i in range(n_new_sets)] if len(new_set_labels) != n_new_sets: raise ValueError(f"Argument new_set_labels must have length {n_new_sets}, not 
def merge_sets(
    self: SplitterT,
    columns: tp.Optional[tp.Iterable[tp.Hashable]] = None,
    new_set_label: tp.Optional[tp.Hashable] = None,
    insert_at_last: bool = False,
    wrapper_kwargs: tp.KwargsLike = None,
    init_kwargs: tp.KwargsLike = None,
    **merge_split_kwargs,
) -> SplitterT:
    """Merge multiple sets (columns) into a set (column).

    Arguments `**merge_split_kwargs` are passed to `Splitter.merge_split`.

    If columns are not provided, merges all columns. Columns can be given as integer positions
    or as set labels; duplicates are ignored. The merged column is inserted at the position of
    the first merged column, or at the position of the last one if `insert_at_last` is True.

    Use `new_set_label` to specify the label of the new set. If not provided, labels of the form
    "prefix/0", "prefix/1", ... are collapsed into "prefix" (unless a label with the next number
    exists among the set labels); otherwise, the old labels are joined with "+".
    To provide final labels, define `columns` in `wrapper_kwargs`.

    If the result has a single column and `ndim` is not defined in `wrapper_kwargs`,
    `ndim` is set to 1. The dictionary passed as `wrapper_kwargs` is never modified.

    Raises `ValueError` if there are fewer than two sets, if a column label cannot be found,
    if no columns are provided, or if `wrap_with_fixrange` is False."""
    if self.n_sets < 2:
        raise ValueError("There are no sets to merge")
    set_labels = self.set_labels
    if columns is None:
        columns = range(len(set_labels))
    positions = set()
    for column in columns:
        if not isinstance(column, int):
            position = set_labels.get_indexer([column])[0]
            if position == -1:
                # Report the requested label rather than the -1 returned by the indexer
                raise ValueError(f"Column '{column}' not found")
            column = int(position)
        positions.add(column)
    if not positions:
        raise ValueError("Must provide at least one column to merge")
    # Unique and sorted: the insert position below relies on the number of distinct columns
    columns = sorted(positions)

    # Work on a copy so that neither "columns" nor "ndim" leak into the caller's dict
    wrapper_kwargs = {} if wrapper_kwargs is None else dict(wrapper_kwargs)
    if init_kwargs is None:
        init_kwargs = {}
    merge_split_kwargs = dict(merge_split_kwargs)
    wrap_with_fixrange = merge_split_kwargs.pop("wrap_with_fixrange", None)
    if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
        raise ValueError("Argument wrap_with_fixrange must be True or None")
    merge_split_kwargs["wrap_with_fixrange"] = wrap_with_fixrange

    # Position among the old columns that receives the merged range
    anchor = columns[-1] if insert_at_last else columns[0]
    new_splits_arr = []
    for split in self.splits_arr:
        split_to_merge = [range_ for j, range_ in enumerate(split) if j in positions]
        new_range = self.merge_split(split_to_merge, **merge_split_kwargs)
        new_split = []
        for j in range(self.n_sets):
            if j not in positions:
                new_split.append(split[j])
            elif j == anchor:
                new_split.append(new_range)
        new_splits_arr.append(new_split)
    new_splits_arr = np.asarray(new_splits_arr, dtype=object)

    if "columns" not in wrapper_kwargs:
        if new_set_label is None:
            old_set_labels = set_labels[columns]
            can_aggregate = True
            prefix = None
            suffix = None
            for i, old_set_label in enumerate(old_set_labels):
                if not isinstance(old_set_label, str):
                    can_aggregate = False
                    break
                _prefix, _, _suffix = old_set_label.rpartition("/")
                # isdecimal() rather than isdigit(): the latter accepts characters
                # (such as superscripts) that int() cannot parse
                if not _suffix.isdecimal():
                    can_aggregate = False
                    break
                _suffix = int(_suffix)
                if prefix is None:
                    prefix = _prefix
                    suffix = _suffix
                    continue
                if suffix != 0:
                    can_aggregate = False
                    break
                if _prefix != prefix or _suffix != i:
                    can_aggregate = False
                    break
            if can_aggregate and prefix + "/%d" % len(old_set_labels) not in set_labels:
                new_set_label = prefix
            else:
                new_set_label = "+".join(map(str, old_set_labels))
        new_columns = set_labels.copy()
        new_columns = new_columns.delete(columns)
        if insert_at_last:
            # All merged columns but the last one are removed in front of the insert position
            new_columns = new_columns.insert(columns[-1] - len(columns) + 1, new_set_label)
        else:
            new_columns = new_columns.insert(columns[0], new_set_label)
        wrapper_kwargs["columns"] = new_columns
    if "ndim" not in wrapper_kwargs and len(wrapper_kwargs["columns"]) == 1:
        wrapper_kwargs["ndim"] = 1
    new_wrapper = self.wrapper.replace(**wrapper_kwargs)
    return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
@hybrid_method
def get_range_bounds(
    cls_or_self,
    range_: tp.FixRangeLike,
    index_bounds: bool = False,
    right_inclusive: bool = False,
    check_constant: bool = True,
    template_context: tp.KwargsLike = None,
    index: tp.Optional[tp.IndexLike] = None,
    freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.Tuple[tp.Any, tp.Any]:
    """Get the left and right bound of a range.

    The range is prepared with `Splitter.get_ready_range`. If `check_constant` is True,
    raises an error if the prepared range is not constant.

    If `index_bounds` is True, the bounds are mapped to the index using
    `Splitter.map_bounds_to_index`. Otherwise, they are returned as positions, with one
    subtracted from the right bound if `right_inclusive` is True.

    When called on the class, `index` must be provided.

    !!! note
        The left bound is inclusive. The right bound is exclusive
        unless `right_inclusive` is True."""
    if index is not None:
        index = dt.prepare_dt_index(index)
    elif isinstance(cls_or_self, type):
        raise ValueError("Must provide index")
    else:
        index = cls_or_self.index

    meta = cls_or_self.get_ready_range(
        range_,
        template_context=template_context,
        index=index,
        return_meta=True,
    )
    if check_constant and not meta["is_constant"]:
        raise ValueError("Range is not constant")

    if index_bounds:
        mapped_bounds = cls_or_self.map_bounds_to_index(
            meta["start"],
            meta["stop"],
            right_inclusive=right_inclusive,
            index=index,
            freq=freq,
        )
        meta["start"], meta["stop"] = mapped_bounds
    elif right_inclusive:
        meta["stop"] = meta["stop"] - 1
    return meta["start"], meta["stop"]
def get_bounds(
    self,
    index_bounds: bool = False,
    right_inclusive: bool = False,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    squeeze_one_split: bool = True,
    squeeze_one_set: bool = True,
    index_combine_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Series/DataFrame with the bounds of each range.

    Columns are the bounds "start" and "end". Returns a Series if there is only one split
    and one set and both are squeezed. If only the split is squeezed, rows are sets; if only
    the set is squeezed, rows are splits. Otherwise, rows are split and set labels combined
    with `index_combine_kwargs`.

    Keyword arguments `**kwargs` are passed to `Splitter.get_bounds_arr`."""
    split_grouper = self.get_split_grouper(split_group_by=split_group_by)
    split_labels = self.get_split_labels(split_group_by=split_grouper)
    set_grouper = self.get_set_grouper(set_group_by=set_group_by)
    set_labels = self.get_set_labels(set_group_by=set_grouper)

    # One row per (split, set) pair, one column per bound
    flat_bounds = self.get_bounds_arr(
        index_bounds=index_bounds,
        right_inclusive=right_inclusive,
        split_group_by=split_grouper,
        set_group_by=set_grouper,
        **kwargs,
    ).reshape((-1, 2))

    squeeze_split = len(split_labels) == 1 and squeeze_one_split
    squeeze_set = len(set_labels) == 1 and squeeze_one_set
    bound_columns = pd.Index(["start", "end"], name="bound")

    if squeeze_split and squeeze_set:
        return pd.Series(flat_bounds[0], index=bound_columns)
    if squeeze_split:
        row_index = set_labels
    elif squeeze_set:
        row_index = split_labels
    else:
        combine_kwargs = {} if index_combine_kwargs is None else index_combine_kwargs
        row_index = combine_indexes((split_labels, set_labels), **combine_kwargs)
    return pd.DataFrame(flat_bounds, index=row_index, columns=bound_columns)
def get_iter_split_mask_arrs(
    self,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> tp.Iterator[tp.Array2d]:
    """Generator of two-dimensional boolean arrays, one per split.

    Rows of each array represent sets, columns represent positions in the index.
    Each range is selected using `Splitter.select_range` and turned into a mask
    using `Splitter.get_range_mask`.

    Keyword arguments `**kwargs` are passed to `Splitter.get_range_mask`."""
    split_grouper = self.get_split_grouper(split_group_by=split_group_by)
    split_count = self.get_n_splits(split_group_by=split_grouper)
    set_grouper = self.get_set_grouper(set_group_by=set_group_by)
    set_count = self.get_n_sets(set_group_by=set_grouper)

    def _range_mask(split_pos, set_pos):
        # Select the range by position and convert it into a mask over the index
        selected = self.select_range(
            split=PosSel(split_pos),
            set_=PosSel(set_pos),
            split_group_by=split_grouper,
            set_group_by=set_grouper,
            merge_split_kwargs={"template_context": template_context},
        )
        return self.get_range_mask(selected, template_context=template_context, **kwargs)

    for split_pos in range(split_count):
        split_mask = np.zeros((set_count, len(self.index)), dtype=np.bool_)
        for set_pos in range(set_count):
            split_mask[set_pos, :] = _range_mask(split_pos, set_pos)
        yield split_mask
def get_iter_split_masks(
    self,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    **kwargs,
) -> tp.Iterator[tp.Frame]:
    """Generator of boolean DataFrames, one per split.

    Each DataFrame has `Splitter.index` as index and set labels as columns.

    Keyword arguments `**kwargs` are passed to `Splitter.get_iter_split_mask_arrs`."""
    split_grouper = self.get_split_grouper(split_group_by=split_group_by)
    set_grouper = self.get_set_grouper(set_group_by=set_group_by)
    set_labels = self.get_set_labels(set_group_by=set_grouper)
    mask_arrs = self.get_iter_split_mask_arrs(
        split_group_by=split_grouper,
        set_group_by=set_grouper,
        **kwargs,
    )
    for mask_arr in mask_arrs:
        # Move the index axis to the front so that sets become columns
        index_first = np.moveaxis(mask_arr, -1, 0)
        yield pd.DataFrame(index_first, index=self.index, columns=set_labels)
def get_mask(
    self,
    split_group_by: tp.AnyGroupByLike = None,
    set_group_by: tp.AnyGroupByLike = None,
    squeeze_one_split: bool = True,
    squeeze_one_set: bool = True,
    index_combine_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Boolean Series/DataFrame where index is `Splitter.index` and columns are ranges.

    Returns a Series if there is only one split and one set and both are squeezed.
    If only the split is squeezed, columns are sets; if only the set is squeezed, columns
    are splits. Otherwise, columns are split and set labels combined with `index_combine_kwargs`.

    Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.

    !!! warning
        Boolean arrays for a big number of splits may take a considerable amount of memory."""
    split_grouper = self.get_split_grouper(split_group_by=split_group_by)
    split_labels = self.get_split_labels(split_group_by=split_grouper)
    set_grouper = self.get_set_grouper(set_group_by=set_group_by)
    set_labels = self.get_set_labels(set_group_by=set_grouper)

    mask_arr = self.get_mask_arr(split_group_by=split_grouper, set_group_by=set_grouper, **kwargs)
    # Bring the index axis to the front and flatten splits and sets into columns
    index_first = np.moveaxis(mask_arr, -1, 0)
    flat_mask = index_first.reshape((len(self.index), -1))

    squeeze_split = len(split_labels) == 1 and squeeze_one_split
    squeeze_set = len(set_labels) == 1 and squeeze_one_set

    if squeeze_split and squeeze_set:
        return pd.Series(flat_mask[:, 0], index=self.index)
    if squeeze_split:
        column_labels = set_labels
    elif squeeze_set:
        column_labels = split_labels
    else:
        combine_kwargs = {} if index_combine_kwargs is None else index_combine_kwargs
        column_labels = combine_indexes((split_labels, set_labels), **combine_kwargs)
    return pd.DataFrame(flat_mask, index=self.index, columns=column_labels)
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.""" split_group_by = self.get_split_grouper(split_group_by=split_group_by) split_labels = self.get_split_labels(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs) if overlapping: if normalize: coverage = (mask_arr.sum(axis=1) > 1).sum(axis=1) / mask_arr.any(axis=1).sum(axis=1) else: coverage = (mask_arr.sum(axis=1) > 1).sum(axis=1) else: if normalize: if relative: coverage = mask_arr.any(axis=1).sum(axis=1) / mask_arr.any(axis=(0, 1)).sum() else: coverage = mask_arr.any(axis=1).mean(axis=1) else: coverage = mask_arr.any(axis=1).sum(axis=1) one_split = len(split_labels) == 1 and squeeze_one_split if one_split: return coverage[0] return pd.Series(coverage, index=split_labels, name="split_coverage") @property def split_coverage(self) -> tp.Series: """`Splitter.get_split_coverage` with default arguments.""" return self.get_split_coverage() def get_set_coverage( self, overlapping: bool = False, normalize: bool = True, relative: bool = False, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, squeeze_one_set: bool = True, **kwargs, ) -> tp.MaybeSeries: """Get the coverage of each set mask. If `overlapping` is True, returns the number of overlapping True values between splits in each set. If `normalize` is True, returns the number of True values in each set relative to the length of the index. If `normalize` and `relative` are True, returns the number of True values in each set relative to the total number of True values across all sets. 
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.""" split_group_by = self.get_split_grouper(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = self.get_set_labels(set_group_by=set_group_by) mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs) if overlapping: if normalize: coverage = (mask_arr.sum(axis=0) > 1).sum(axis=1) / mask_arr.any(axis=0).sum(axis=1) else: coverage = (mask_arr.sum(axis=0) > 1).sum(axis=1) else: if normalize: if relative: coverage = mask_arr.any(axis=0).sum(axis=1) / mask_arr.any(axis=(0, 1)).sum() else: coverage = mask_arr.any(axis=0).mean(axis=1) else: coverage = mask_arr.any(axis=0).sum(axis=1) one_set = len(set_labels) == 1 and squeeze_one_set if one_set: return coverage[0] return pd.Series(coverage, index=set_labels, name="set_coverage") @property def set_coverage(self) -> tp.Series: """`Splitter.get_set_coverage` with default arguments.""" return self.get_set_coverage() def get_range_coverage( self, normalize: bool = True, relative: bool = False, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, squeeze_one_split: bool = True, squeeze_one_set: bool = True, index_combine_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeSeries: """Get the coverage of each range mask. If `normalize` is True, returns the number of True values in each range relative to the length of the index. If `normalize` and `relative` are True, returns the number of True values in each range relative to the total number of True values in its split. 
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.""" split_group_by = self.get_split_grouper(split_group_by=split_group_by) split_labels = self.get_split_labels(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = self.get_set_labels(set_group_by=set_group_by) mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs) if normalize: if relative: coverage = (mask_arr.sum(axis=2) / mask_arr.any(axis=1).sum(axis=1)[:, None]).flatten() else: coverage = (mask_arr.sum(axis=2) / mask_arr.shape[2]).flatten() else: coverage = mask_arr.sum(axis=2).flatten() one_split = len(split_labels) == 1 and squeeze_one_split one_set = len(set_labels) == 1 and squeeze_one_set if one_split and one_set: return coverage[0] if one_split: return pd.Series(coverage, index=set_labels, name="range_coverage") if one_set: return pd.Series(coverage, index=split_labels, name="range_coverage") if index_combine_kwargs is None: index_combine_kwargs = {} index = combine_indexes((split_labels, set_labels), **index_combine_kwargs) return pd.Series(coverage, index=index, name="range_coverage") @property def range_coverage(self) -> tp.Series: """`Splitter.get_range_coverage` with default arguments.""" return self.get_range_coverage() def get_coverage( self, overlapping: bool = False, normalize: bool = True, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, **kwargs, ) -> float: """Get the coverage of the entire mask. If `overlapping` is True, returns the number of overlapping True values. If `normalize` is True, returns the number of True values relative to the length of the index. If `overlapping` and `normalize` are True, returns the number of overlapping True values relative to the total number of True values. 
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.""" mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs) if overlapping: if normalize: return (mask_arr.sum(axis=(0, 1)) > 1).sum() / mask_arr.any(axis=(0, 1)).sum() return (mask_arr.sum(axis=(0, 1)) > 1).sum() if normalize: return mask_arr.any(axis=(0, 1)).mean() return mask_arr.any(axis=(0, 1)).sum() @property def coverage(self) -> float: """`Splitter.get_coverage` with default arguments.""" return self.get_coverage() def get_overlap_matrix( self, by: str = "split", normalize: bool = True, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, jitted: tp.JittedOption = None, squeeze_one_split: bool = True, squeeze_one_set: bool = True, index_combine_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.Frame: """Get the overlap between each pair of ranges. The argument `by` can be one of 'split', 'set', and 'range'. If `normalize` is True, returns the number of True values in each overlap relative to the total number of True values in both ranges. 
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.""" split_group_by = self.get_split_grouper(split_group_by=split_group_by) split_labels = self.get_split_labels(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = self.get_set_labels(set_group_by=set_group_by) mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs) one_split = len(split_labels) == 1 and squeeze_one_split one_set = len(set_labels) == 1 and squeeze_one_set if by.lower() == "split": if normalize: func = jit_reg.resolve_option(nb.norm_split_overlap_matrix_nb, jitted) else: func = jit_reg.resolve_option(nb.split_overlap_matrix_nb, jitted) overlap_matrix = func(mask_arr) if one_split: return overlap_matrix[0, 0] index = split_labels elif by.lower() == "set": if normalize: func = jit_reg.resolve_option(nb.norm_set_overlap_matrix_nb, jitted) else: func = jit_reg.resolve_option(nb.set_overlap_matrix_nb, jitted) overlap_matrix = func(mask_arr) if one_set: return overlap_matrix[0, 0] index = set_labels elif by.lower() == "range": if normalize: func = jit_reg.resolve_option(nb.norm_range_overlap_matrix_nb, jitted) else: func = jit_reg.resolve_option(nb.range_overlap_matrix_nb, jitted) overlap_matrix = func(mask_arr) if one_split and one_set: return overlap_matrix[0, 0] if one_split: index = set_labels elif one_set: index = split_labels else: if index_combine_kwargs is None: index_combine_kwargs = {} index = combine_indexes((split_labels, set_labels), **index_combine_kwargs) else: raise ValueError(f"Invalid by: '{by}'") return pd.DataFrame(overlap_matrix, index=index, columns=index) @property def split_overlap_matrix(self) -> tp.Frame: """`Splitter.get_overlap_matrix` with `by="split"`.""" return self.get_overlap_matrix(by="split") @property def set_overlap_matrix(self) -> tp.Frame: """`Splitter.get_overlap_matrix` with `by="set"`.""" return self.get_overlap_matrix(by="set") @property def 
range_overlap_matrix(self) -> tp.Frame: """`Splitter.get_overlap_matrix` with `by="range"`.""" return self.get_overlap_matrix(by="range") # ############# Stats ############# # @property def stats_defaults(self) -> tp.Kwargs: """Defaults for `Splitter.stats`. Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and `stats` from `vectorbtpro._settings.splitter`.""" from vectorbtpro._settings import settings splitter_stats_cfg = settings["splitter"]["stats"] return merge_dicts(Analyzable.stats_defaults.__get__(self), splitter_stats_cfg) _metrics: tp.ClassVar[Config] = HybridConfig( dict( start=dict( title="Index Start", calc_func=lambda self: self.index[0], agg_func=None, tags=["splitter", "index"], ), end=dict( title="Index End", calc_func=lambda self: self.index[-1], agg_func=None, tags=["splitter", "index"], ), period=dict( title="Index Length", calc_func=lambda self: len(self.index), agg_func=None, tags=["splitter", "index"], ), split_count=dict( title="Splits", calc_func="n_splits", agg_func=None, tags=["splitter", "splits"], ), set_count=dict( title="Sets", calc_func="n_sets", agg_func=None, tags=["splitter", "splits"], ), coverage=dict( title=RepFunc(lambda normalize: "Coverage [%]" if normalize else "Coverage"), calc_func="coverage", overlapping=False, post_calc_func=lambda self, out, settings: out * 100 if settings["normalize"] else out, agg_func=None, tags=["splitter", "splits", "coverage"], ), set_coverage=dict( title=RepFunc(lambda normalize: "Coverage [%]" if normalize else "Coverage"), check_has_multiple_sets=True, calc_func="set_coverage", overlapping=False, relative=False, post_calc_func=lambda self, out, settings: to_dict( out * 100 if settings["normalize"] else out, orient="index_series" ), agg_func=None, tags=["splitter", "splits", "coverage"], ), set_mean_rel_coverage=dict( title="Mean Rel Coverage [%]", check_has_multiple_sets=True, check_normalize=True, calc_func="range_coverage", relative=True, post_calc_func=lambda self, 
out, settings: to_dict( out.groupby(self.get_set_labels(set_group_by=settings.get("set_group_by", None)).names).mean()[ self.get_set_labels(set_group_by=settings.get("set_group_by", None)) ] * 100, orient="index_series", ), agg_func=None, tags=["splitter", "splits", "coverage"], ), overlap_coverage=dict( title=RepFunc(lambda normalize: "Overlap Coverage [%]" if normalize else "Overlap Coverage"), calc_func="coverage", overlapping=True, post_calc_func=lambda self, out, settings: out * 100 if settings["normalize"] else out, agg_func=None, tags=["splitter", "splits", "coverage"], ), set_overlap_coverage=dict( title=RepFunc(lambda normalize: "Overlap Coverage [%]" if normalize else "Overlap Coverage"), check_has_multiple_sets=True, calc_func="set_coverage", overlapping=True, post_calc_func=lambda self, out, settings: to_dict( out * 100 if settings["normalize"] else out, orient="index_series" ), agg_func=None, tags=["splitter", "splits", "coverage"], ), ) ) @property def metrics(self) -> Config: return self._metrics # ############# Plotting ############# # def plot( self, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, mask_kwargs: tp.KwargsLike = None, trace_kwargs: tp.KwargsLikeSequence = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot splits as rows and sets as colors. Args: split_group_by (any): Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`. set_group_by (any): Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`. mask_kwargs (dict): Keyword arguments passed to `Splitter.get_iter_set_masks`. trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`. Can be a sequence, one per set. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. 
Usage: * Plot a scikit-learn splitter: ```pycon >>> from vectorbtpro import * >>> from sklearn.model_selection import TimeSeriesSplit >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_sklearn(index, TimeSeriesSplit()) >>> splitter.plot().show() ``` ![](/assets/images/api/Splitter.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/Splitter.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") from vectorbtpro.utils.figure import make_figure import plotly.express as px if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) split_group_by = self.get_split_grouper(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = self.get_set_labels(set_group_by=set_group_by) if fig.layout.colorway is not None: colorway = fig.layout.colorway else: colorway = fig.layout.template.layout.colorway if len(set_labels) > len(colorway): colorway = px.colors.qualitative.Alphabet if self.get_n_splits(split_group_by=split_group_by) > 0: if self.get_n_sets(set_group_by=set_group_by) > 0: if mask_kwargs is None: mask_kwargs = {} for i, mask in enumerate( self.get_iter_set_masks( split_group_by=split_group_by, set_group_by=set_group_by, **mask_kwargs, ) ): df = mask.vbt.wrapper.fill() df[mask] = i color = adjust_opacity(colorway[i % len(colorway)], 0.8) trace_name = str(set_labels[i]) _trace_kwargs = merge_dicts( dict( showscale=False, showlegend=True, legendgroup=str(set_labels[i]), name=trace_name, colorscale=[color, color], hovertemplate="%{x}
Split: %{y}
Set: " + trace_name, ), resolve_dict(trace_kwargs, i=i), ) fig = df.vbt.ts_heatmap( trace_kwargs=_trace_kwargs, add_trace_kwargs=add_trace_kwargs, is_y_category=True, fig=fig, ) return fig def plot_coverage( self, stacked: bool = True, split_group_by: tp.AnyGroupByLike = None, set_group_by: tp.AnyGroupByLike = None, mask_kwargs: tp.KwargsLike = None, trace_kwargs: tp.KwargsLikeSequence = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot index as rows and sets as lines. Args: stacked (bool): Whether to plot as an area plot. split_group_by (any): Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`. set_group_by (any): Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`. mask_kwargs (dict): Keyword arguments passed to `Splitter.get_iter_set_masks`. trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`. Can be a sequence, one per set. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. 
Usage: * Area plot: ```pycon >>> from vectorbtpro import * >>> from sklearn.model_selection import TimeSeriesSplit >>> index = pd.date_range("2020", "2021", freq="D") >>> splitter = vbt.Splitter.from_sklearn(index, TimeSeriesSplit()) >>> splitter.plot_coverage().show() ``` ![](/assets/images/api/Splitter_coverage_area.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/Splitter_coverage_area.dark.svg#only-dark){: .iimg loading=lazy } * Line plot: ```pycon >>> splitter.plot_coverage(stacked=False).show() ``` ![](/assets/images/api/Splitter_coverage_line.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/Splitter_coverage_line.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") from vectorbtpro.utils.figure import make_figure import plotly.express as px if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) split_group_by = self.get_split_grouper(split_group_by=split_group_by) set_group_by = self.get_set_grouper(set_group_by=set_group_by) set_labels = self.get_set_labels(set_group_by=set_group_by) if fig.layout.colorway is not None: colorway = fig.layout.colorway else: colorway = fig.layout.template.layout.colorway if len(set_labels) > len(colorway): colorway = px.colors.qualitative.Alphabet if self.get_n_splits(split_group_by=split_group_by) > 0: if self.get_n_sets(set_group_by=set_group_by) > 0: if mask_kwargs is None: mask_kwargs = {} for i, mask in enumerate( self.get_iter_set_masks( split_group_by=split_group_by, set_group_by=set_group_by, **mask_kwargs, ) ): _trace_kwargs = merge_dicts( dict( stackgroup="coverage" if stacked else None, legendgroup=str(set_labels[i]), name=str(set_labels[i]), line=dict(color=colorway[i % len(colorway)], shape="hv"), ), resolve_dict(trace_kwargs, i=i), ) fig = mask.sum(axis=1).vbt.lineplot( trace_kwargs=_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) return fig @property def 
plots_defaults(self) -> tp.Kwargs: """Defaults for `Splitter.plots`. Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and `plots` from `vectorbtpro._settings.splitter`.""" from vectorbtpro._settings import settings splitter_plots_cfg = settings["splitter"]["plots"] return merge_dicts(Analyzable.plots_defaults.__get__(self), splitter_plots_cfg) _subplots: tp.ClassVar[Config] = HybridConfig( dict( plot=dict( title="Splits", yaxis_kwargs=dict(title="Split"), plot_func="plot", tags="splitter", ), plot_coverage=dict( title="Coverage", yaxis_kwargs=dict(title="Count"), plot_func="plot_coverage", tags="splitter", ), ) ) @property def subplots(self) -> Config: return self._subplots Splitter.override_metrics_doc(__pdoc__) Splitter.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Decorators for splitting."""

import inspect
from functools import wraps

from vectorbtpro import _typing as tp
from vectorbtpro.generic.splitting.base import Splitter, Takeable
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import FrozenConfig, merge_dicts
from vectorbtpro.utils.execution import NoResult, NoResultsException
from vectorbtpro.utils.params import parameterized
from vectorbtpro.utils.parsing import (
    annotate_args,
    flatten_ann_args,
    unflatten_ann_args,
    ann_args_to_args,
    match_ann_arg,
    get_func_arg_names,
)
from vectorbtpro.utils.template import Rep, RepEval, substitute_templates

__all__ = [
    "split",
    "cv_split",
]


def _ensure_var_kwargs(func: tp.Callable) -> None:
    """Append a `**kwargs` parameter to the reported signature of `func` if it has none.

    The wrappers in this module accept underscore-prefixed option overrides as keyword arguments,
    so their signature must advertise variable keyword arguments even if the wrapped function doesn't.
    Sets `func.__signature__` in place; does nothing if a variable keyword parameter already exists."""
    signature = inspect.signature(func)
    for param in signature.parameters.values():
        if param.kind == param.VAR_KEYWORD:
            return
    var_kwargs_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
    new_parameters = tuple(signature.parameters.values()) + (var_kwargs_param,)
    func.__signature__ = signature.replace(parameters=new_parameters)


def split(
    *args,
    splitter: tp.Union[None, str, Splitter, tp.Callable] = None,
    splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
    splitter_kwargs: tp.KwargsLike = None,
    index: tp.Optional[tp.IndexLike] = None,
    index_from: tp.Optional[tp.AnnArgQuery] = None,
    takeable_args: tp.Optional[tp.MaybeIterable[tp.AnnArgQuery]] = None,
    template_context: tp.KwargsLike = None,
    forward_kwargs_as: tp.KwargsLike = None,
    return_splitter: bool = False,
    apply_kwargs: tp.KwargsLike = None,
    **var_kwargs,
) -> tp.Callable:
    """Decorator that splits the inputs of a function.

    Does the following:

    1. Resolves the splitter of the type `vectorbtpro.generic.splitting.base.Splitter` using
    the argument `splitter`. It can be either an already provided splitter instance, the
    name of a factory method (such as "from_n_rolling"), or the factory method itself.
    If `splitter` is None, the right method will be guessed based on the supplied arguments
    using `vectorbtpro.generic.splitting.base.Splitter.guess_method`. To construct a splitter,
    it will pass `index` and `**splitter_kwargs`. Index is getting resolved either using an already
    provided `index`, by parsing the argument under a name/position provided in `index_from`,
    or by parsing the first argument from `takeable_args` (in this order).
    2. Wraps arguments in `takeable_args` with `vectorbtpro.generic.splitting.base.Takeable`
    3. Runs `vectorbtpro.generic.splitting.base.Splitter.apply` with arguments passed
    to the function as `args` and `kwargs`, but also `apply_kwargs` (the ones passed to the decorator)

    Keyword arguments `splitter_kwargs` are passed to the factory method. Keyword arguments
    `apply_kwargs` are passed to `vectorbtpro.generic.splitting.base.Splitter.apply`. If variable
    keyword arguments are provided, they will be used as `splitter_kwargs` if `apply_kwargs` is already set,
    and vice versa. If `splitter_kwargs` and `apply_kwargs` aren't set, they will be used as `splitter_kwargs`
    if a splitter instance hasn't been built yet, otherwise as `apply_kwargs`. If both arguments are set,
    will raise an error.

    Each decorator option can be overridden per call by passing it as a keyword argument prefixed
    with an underscore (such as `_splitter` or `_index`). The "first" takeable argument is the first
    one in the order given in `takeable_args`, followed by arguments detected as
    `vectorbtpro.generic.splitting.base.Takeable` in the call.

    Usage:
        * Split a Series and return its sum:

        ```pycon
        >>> from vectorbtpro import *

        >>> @vbt.split(
        ...     splitter="from_n_rolling",
        ...     splitter_kwargs=dict(n=2),
        ...     takeable_args=["sr"]
        ... )
        ... def f(sr):
        ...     return sr.sum()

        >>> index = pd.date_range("2020-01-01", "2020-01-06")
        >>> sr = pd.Series(np.arange(len(index)), index=index)
        >>> f(sr)
        split
        0     3
        1    12
        dtype: int64
        ```

        * Perform a split manually:

        ```pycon
        >>> @vbt.split(
        ...     splitter="from_n_rolling",
        ...     splitter_kwargs=dict(n=2),
        ...     takeable_args=["index"]
        ... )
        ... def f(index, sr):
        ...     return sr[index].sum()

        >>> f(index, sr)
        split
        0     3
        1    12
        dtype: int64
        ```

        * Construct splitter and mark arguments as "takeable" manually:

        ```pycon
        >>> splitter = vbt.Splitter.from_n_rolling(index, n=2)
        >>> @vbt.split(splitter=splitter)
        ... def f(sr):
        ...     return sr.sum()

        >>> f(vbt.Takeable(sr))
        split
        0     3
        1    12
        dtype: int64
        ```

        * Split multiple timeframes using a custom index:

        ```pycon
        >>> @vbt.split(
        ...     splitter="from_n_rolling",
        ...     splitter_kwargs=dict(n=2),
        ...     index=index,
        ...     takeable_args=["h12_sr", "d2_sr"]
        ... )
        ... def f(h12_sr, d2_sr):
        ...     return h12_sr.sum() + d2_sr.sum()

        >>> h12_index = pd.date_range("2020-01-01", "2020-01-06", freq="12H")
        >>> d2_index = pd.date_range("2020-01-01", "2020-01-06", freq="2D")
        >>> h12_sr = pd.Series(np.arange(len(h12_index)), index=h12_index)
        >>> d2_sr = pd.Series(np.arange(len(d2_index)), index=d2_index)
        >>> f(h12_sr, d2_sr)
        split
        0    15
        1    42
        dtype: int64
        ```
    """

    def decorator(func: tp.Callable) -> tp.Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> tp.Any:
            # NOTE: local names below are looked up through locals() when forwarding,
            # so they must keep their names
            splitter = kwargs.pop("_splitter", wrapper.options["splitter"])
            splitter_cls = kwargs.pop("_splitter_cls", wrapper.options["splitter_cls"])
            splitter_kwargs = merge_dicts(wrapper.options["splitter_kwargs"], kwargs.pop("_splitter_kwargs", {}))
            index = kwargs.pop("_index", wrapper.options["index"])
            index_from = kwargs.pop("_index_from", wrapper.options["index_from"])
            takeable_args = kwargs.pop("_takeable_args", wrapper.options["takeable_args"])
            # Track the given order next to the set: iteration order of a set of strings
            # is not stable across interpreter runs, but the index fallback below
            # depends on which takeable argument comes first
            if takeable_args is None:
                ordered_takeables = []
            elif checks.is_iterable(takeable_args) and not isinstance(takeable_args, str):
                ordered_takeables = list(dict.fromkeys(takeable_args))
            else:
                ordered_takeables = [takeable_args]
            takeable_args = set(ordered_takeables)
            template_context = merge_dicts(wrapper.options["template_context"], kwargs.pop("_template_context", {}))
            apply_kwargs = merge_dicts(wrapper.options["apply_kwargs"], kwargs.pop("_apply_kwargs", {}))
            return_splitter = kwargs.pop("_return_splitter", wrapper.options["return_splitter"])
            forward_kwargs_as = merge_dicts(wrapper.options["forward_kwargs_as"], kwargs.pop("_forward_kwargs_as", {}))
            if len(forward_kwargs_as) > 0:
                # Rename user-passed keyword arguments first
                new_kwargs = dict()
                for k, v in kwargs.items():
                    if k in forward_kwargs_as:
                        new_kwargs[forward_kwargs_as.pop(k)] = v
                    else:
                        new_kwargs[k] = v
                kwargs = new_kwargs
            if len(forward_kwargs_as) > 0:
                # Remaining keys refer to local variables of this wrapper (e.g., "index")
                for k, v in forward_kwargs_as.items():
                    kwargs[v] = locals()[k]

            if splitter_cls is None:
                splitter_cls = Splitter
            if len(var_kwargs) > 0:
                # Distribute variable keyword arguments among the factory method and apply
                var_splitter_kwargs = {}
                var_apply_kwargs = {}
                if splitter is None or not isinstance(splitter, splitter_cls):
                    apply_arg_names = get_func_arg_names(splitter_cls.apply)
                    if splitter is not None:
                        if isinstance(splitter, str):
                            splitter_arg_names = get_func_arg_names(getattr(splitter_cls, splitter))
                        else:
                            splitter_arg_names = get_func_arg_names(splitter)
                        for k, v in var_kwargs.items():
                            assigned = False
                            if k in splitter_arg_names:
                                var_splitter_kwargs[k] = v
                                assigned = True
                            if k != "split" and k in apply_arg_names:
                                var_apply_kwargs[k] = v
                                assigned = True
                            if not assigned:
                                raise ValueError(f"Argument '{k}' couldn't be assigned")
                    else:
                        for k, v in var_kwargs.items():
                            if k == "freq":
                                var_splitter_kwargs[k] = v
                                var_apply_kwargs[k] = v
                            elif k == "split" or k not in apply_arg_names:
                                var_splitter_kwargs[k] = v
                            else:
                                var_apply_kwargs[k] = v
                else:
                    var_apply_kwargs = var_kwargs
                splitter_kwargs = merge_dicts(var_splitter_kwargs, splitter_kwargs)
                apply_kwargs = merge_dicts(var_apply_kwargs, apply_kwargs)
            if len(splitter_kwargs) > 0:
                if splitter is None:
                    splitter = splitter_cls.guess_method(**splitter_kwargs)
                if splitter is None:
                    raise ValueError("Splitter method couldn't be guessed")
            else:
                if splitter is None:
                    raise ValueError("Must provide splitter or splitter method")
            if not isinstance(splitter, splitter_cls) and index is not None:
                # Index is known upfront: build the splitter before touching the arguments
                if isinstance(splitter, str):
                    splitter = getattr(splitter_cls, splitter)
                splitter = splitter(index, template_context=template_context, **splitter_kwargs)
                if return_splitter:
                    return splitter

            ann_args = annotate_args(func, args, kwargs, attach_annotations=True)
            flat_ann_args = flatten_ann_args(ann_args)
            if isinstance(splitter, splitter_cls):
                flat_ann_args = splitter.parse_and_inject_takeables(flat_ann_args)
            else:
                flat_ann_args = splitter_cls.parse_and_inject_takeables(flat_ann_args)
            for k, v in flat_ann_args.items():
                if isinstance(v["value"], Takeable) and k not in takeable_args:
                    takeable_args.add(k)
                    ordered_takeables.append(k)
            for takeable_arg in ordered_takeables:
                arg_name = match_ann_arg(ann_args, takeable_arg, return_name=True)
                if not isinstance(flat_ann_args[arg_name]["value"], Takeable):
                    flat_ann_args[arg_name]["value"] = Takeable(flat_ann_args[arg_name]["value"])
            new_ann_args = unflatten_ann_args(flat_ann_args)
            args, kwargs = ann_args_to_args(new_ann_args)

            if not isinstance(splitter, splitter_cls):
                if index is None and index_from is not None:
                    index = splitter_cls.get_obj_index(match_ann_arg(ann_args, index_from))
                if index is None and len(ordered_takeables) > 0:
                    first_takeable = match_ann_arg(ann_args, ordered_takeables[0])
                    if isinstance(first_takeable, Takeable):
                        first_takeable = first_takeable.obj
                    index = splitter_cls.get_obj_index(first_takeable)
                if index is None:
                    raise ValueError("Must provide splitter, index, index_from, or takeable_args")
                if isinstance(splitter, str):
                    splitter = getattr(splitter_cls, splitter)
                splitter = splitter(index, template_context=template_context, **splitter_kwargs)
            if return_splitter:
                return splitter

            return splitter.apply(
                func,
                *args,
                **kwargs,
                **apply_kwargs,
            )

        wrapper.func = func
        wrapper.name = func.__name__
        wrapper.is_split = True
        wrapper.options = FrozenConfig(
            splitter=splitter,
            splitter_cls=splitter_cls,
            splitter_kwargs=splitter_kwargs,
            index=index,
            index_from=index_from,
            takeable_args=takeable_args,
            template_context=template_context,
            forward_kwargs_as=forward_kwargs_as,
            return_splitter=return_splitter,
            apply_kwargs=apply_kwargs,
            var_kwargs=var_kwargs,
        )
        _ensure_var_kwargs(wrapper)

        return wrapper

    if len(args) == 0:
        return decorator
    elif len(args) == 1:
        return decorator(args[0])
    raise ValueError("Either function or keyword arguments must be passed")


def cv_split(
    *args,
    parameterized_kwargs: tp.KwargsLike = None,
    selection: tp.Union[str, tp.Selection] = "max",
    return_grid: tp.Union[bool, str] = False,
    skip_errored: bool = False,
    raise_no_results: bool = True,
    template_context: tp.KwargsLike = None,
    **split_kwargs,
) -> tp.Callable:
    """Decorator that combines `split` and `vectorbtpro.utils.params.parameterized` for cross-validation.

    Creates a new apply function that is going to be decorated with `split` and thus applied
    at each single range using `vectorbtpro.generic.splitting.base.Splitter.apply`. Inside
    this apply function, there is a test whether the current range belongs to the first (training) set.
    If yes, parameterizes the underlying function and runs it on the entire grid of parameters.
    The returned results are then stored in a global list. These results are then read by the other
    (testing) sets in the same split. If `selection` is a template, it can evaluate the grid results
    (available as `grid_results`) and return the best parameter combination. This parameter combination
    is then executed by each set (including training).

    Argument `selection` also accepts "min" for `np.argmin` and "max" for `np.argmax`.

    Keyword arguments `parameterized_kwargs` will be passed to `vectorbtpro.utils.params.parameterized`
    and will have their templates substituted with a context that will also include the split-related context
    (including `split_idx`, `set_idx`, etc., see `vectorbtpro.generic.splitting.base.Splitter.apply`).

    If `return_grid` is True or 'first', returns both the grid and the selection. If `return_grid`
    is 'all', executes the grid on each set and returns along with the selection.
    Otherwise, returns only the selection.

    If `vectorbtpro.utils.execution.NoResultsException` is raised or `skip_errored` is True and
    any exception is raised, will skip the current iteration and remove it from the final index.

    If the training set of a split yielded no results, the remaining sets of that split raise
    `vectorbtpro.utils.execution.NoResultsException` if `raise_no_results` is True, otherwise they
    return `vectorbtpro.utils.execution.NoResult` directly. The exception is raised within the apply
    function's own error handling, which converts it into `vectorbtpro.utils.execution.NoResult`.

    Usage:
        * Permutate a series and pick the first value. Make the seed parameterizable.
        Cross-validate based on the highest picked value:

        ```pycon
        >>> from vectorbtpro import *

        >>> @vbt.cv_split(
        ...     splitter="from_n_rolling",
        ...     splitter_kwargs=dict(n=3, split=0.5),
        ...     takeable_args=["sr"],
        ...     merge_func="concat",
        ... )
        ... def f(sr, seed):
        ...     np.random.seed(seed)
        ...     return np.random.permutation(sr)[0]

        >>> index = pd.date_range("2020-01-01", "2020-02-01")
        >>> np.random.seed(0)
        >>> sr = pd.Series(np.random.permutation(np.arange(len(index))), index=index)
        >>> f(sr, vbt.Param([41, 42, 43]))
        split  set    seed
        0      set_0  41      22
               set_1  41      28
        1      set_0  43       8
               set_1  43      31
        2      set_0  43      19
               set_1  43       0
        dtype: int64
        ```

        * Extend the example above to also return the grid results of each set:

        ```pycon
        >>> f(sr, vbt.Param([41, 42, 43]), _return_grid="all")
        (split  set    seed
         0      set_0  41      22
                       42      22
                       43       2
                set_1  41      28
                       42      28
                       43      20
         1      set_0  41       5
                       42       5
                       43       8
                set_1  41      23
                       42      23
                       43      31
         2      set_0  41      18
                       42      18
                       43      19
                set_1  41      27
                       42      27
                       43       0
         dtype: int64,
         split  set    seed
         0      set_0  41      22
                set_1  41      28
         1      set_0  43       8
                set_1  43      31
         2      set_0  43      19
                set_1  43       0
         dtype: int64)
        ```
    """

    def decorator(func: tp.Callable) -> tp.Callable:
        if getattr(func, "is_split", False) or getattr(func, "is_parameterized", False):
            raise ValueError("Function is already decorated with split or parameterized")

        @wraps(func)
        def wrapper(*args, **kwargs) -> tp.Any:
            parameterized_kwargs = merge_dicts(
                wrapper.options["parameterized_kwargs"],
                kwargs.pop("_parameterized_kwargs", {}),
            )
            selection = kwargs.pop("_selection", wrapper.options["selection"])
            if isinstance(selection, str) and selection.lower() == "min":
                selection = RepEval("[np.nanargmin(grid_results)]")
            if isinstance(selection, str) and selection.lower() == "max":
                selection = RepEval("[np.nanargmax(grid_results)]")
            return_grid = kwargs.pop("_return_grid", wrapper.options["return_grid"])
            if isinstance(return_grid, bool):
                if return_grid:
                    return_grid = "first"
                else:
                    return_grid = None
            skip_errored = kwargs.pop("_skip_errored", wrapper.options["skip_errored"])
            template_context = merge_dicts(
                wrapper.options["template_context"],
                kwargs.pop("_template_context", {}),
            )
            split_kwargs = merge_dicts(
                wrapper.options["split_kwargs"],
                kwargs.pop("_split_kwargs", {}),
            )
            # Reuse the merging function of the split for the grid unless one is given explicitly
            if "merge_func" in split_kwargs and "merge_func" not in parameterized_kwargs:
                parameterized_kwargs["merge_func"] = split_kwargs["merge_func"]
            if "show_progress" not in parameterized_kwargs:
                parameterized_kwargs["show_progress"] = False

            # Grid results of training sets; the last entry belongs to the current split
            all_grid_results = []

            @wraps(func)
            def apply_wrapper(*_args, __template_context=None, **_kwargs):
                try:
                    __template_context = dict(__template_context)
                    __template_context["all_grid_results"] = all_grid_results
                    _parameterized_kwargs = substitute_templates(
                        parameterized_kwargs,
                        __template_context,
                        eval_id="parameterized_kwargs",
                    )
                    parameterized_func = parameterized(
                        func,
                        template_context=__template_context,
                        **_parameterized_kwargs,
                    )
                    if __template_context["set_idx"] == 0:
                        try:
                            grid_results = parameterized_func(*_args, **_kwargs)
                            all_grid_results.append(grid_results)
                        except Exception as e:
                            # Leave a marker so that the other sets of this split are skipped too
                            if skip_errored or isinstance(e, NoResultsException):
                                all_grid_results.append(NoResult)
                            raise e
                    if all_grid_results[-1] is NoResult:
                        if raise_no_results:
                            # Caught by the handler below and converted into NoResult
                            raise NoResultsException
                        return NoResult
                    result = parameterized_func(
                        *_args,
                        _selection=selection,
                        _template_context=dict(grid_results=all_grid_results[-1]),
                        **_kwargs,
                    )
                    if return_grid is not None:
                        if return_grid.lower() == "first":
                            return all_grid_results[-1], result
                        if return_grid.lower() == "all":
                            grid_results = parameterized_func(
                                *_args,
                                _template_context=dict(grid_results=all_grid_results[-1]),
                                **_kwargs,
                            )
                            return grid_results, result
                        else:
                            raise ValueError(f"Invalid return_grid: '{return_grid}'")
                    return result
                except Exception as e:
                    if skip_errored or isinstance(e, NoResultsException):
                        return NoResult
                    raise e

            _ensure_var_kwargs(apply_wrapper)
            split_func = split(apply_wrapper, template_context=template_context, **split_kwargs)
            return split_func(*args, __template_context=Rep("context", eval_id="apply_kwargs"), **kwargs)

        wrapper.func = func
        wrapper.name = func.__name__
        wrapper.is_parameterized = True
        wrapper.is_split = True
        wrapper.options = FrozenConfig(
            parameterized_kwargs=parameterized_kwargs,
            selection=selection,
            return_grid=return_grid,
            skip_errored=skip_errored,
            template_context=template_context,
            split_kwargs=split_kwargs,
        )
        _ensure_var_kwargs(wrapper)

        return wrapper

    if len(args) == 0:
        return decorator
    elif len(args) == 1:
        return decorator(args[0])
    raise ValueError("Either function or keyword arguments must be passed")


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
@register_jitted(cache=True, tags={"can_parallel"})
def split_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Count the positions shared by each pair of splits.

    `mask_arr` is indexed as `[split, set, position]`. A position counts as covered by
    a split if it is marked in at least one of the split's sets. Element `[i1, i2]` of the
    returned square matrix is the number of positions covered by both splits."""
    n_splits = mask_arr.shape[0]
    n_positions = mask_arr.shape[2]

    # Collapse the set axis: one boolean row per split
    covered = np.empty((n_splits, n_positions), dtype=np.bool_)
    for split_idx in range(n_splits):
        for pos in range(n_positions):
            covered[split_idx, pos] = mask_arr[split_idx, :, pos].any()

    overlap = np.empty((n_splits, n_splits), dtype=int_)
    for row in prange(n_splits):
        for col in range(n_splits):
            overlap[row, col] = (covered[row] & covered[col]).sum()
    return overlap


@register_jitted(cache=True, tags={"can_parallel"})
def norm_split_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Compute the overlap of each pair of splits relative to their combined coverage.

    Works like `split_overlap_matrix_nb`, but element `[i1, i2]` is the number of
    positions covered by both splits divided by the number covered by either of them."""
    n_splits = mask_arr.shape[0]
    n_positions = mask_arr.shape[2]

    # Collapse the set axis: one boolean row per split
    covered = np.empty((n_splits, n_positions), dtype=np.bool_)
    for split_idx in range(n_splits):
        for pos in range(n_positions):
            covered[split_idx, pos] = mask_arr[split_idx, :, pos].any()

    overlap = np.empty((n_splits, n_splits), dtype=float_)
    for row in prange(n_splits):
        for col in range(n_splits):
            n_shared = (covered[row] & covered[col]).sum()
            n_combined = (covered[row] | covered[col]).sum()
            # NOTE(review): two splits without any covered position give a zero
            # denominator here - confirm the intended result for that case
            overlap[row, col] = n_shared / n_combined
    return overlap


@register_jitted(cache=True, tags={"can_parallel"})
def set_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Count the positions shared by each pair of sets.

    `mask_arr` is indexed as `[split, set, position]`. A position counts as covered by
    a set if it is marked in that set in at least one split. Element `[j1, j2]` of the
    returned square matrix is the number of positions covered by both sets."""
    n_sets = mask_arr.shape[1]
    n_positions = mask_arr.shape[2]

    # Collapse the split axis: one boolean row per set
    covered = np.empty((n_sets, n_positions), dtype=np.bool_)
    for set_idx in range(n_sets):
        for pos in range(n_positions):
            covered[set_idx, pos] = mask_arr[:, set_idx, pos].any()

    overlap = np.empty((n_sets, n_sets), dtype=int_)
    for row in prange(n_sets):
        for col in range(n_sets):
            overlap[row, col] = (covered[row] & covered[col]).sum()
    return overlap


@register_jitted(cache=True, tags={"can_parallel"})
def norm_set_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Compute the overlap of each pair of sets relative to their combined coverage.

    Works like `set_overlap_matrix_nb`, but element `[j1, j2]` is the number of
    positions covered by both sets divided by the number covered by either of them."""
    n_sets = mask_arr.shape[1]
    n_positions = mask_arr.shape[2]

    # Collapse the split axis: one boolean row per set
    covered = np.empty((n_sets, n_positions), dtype=np.bool_)
    for set_idx in range(n_sets):
        for pos in range(n_positions):
            covered[set_idx, pos] = mask_arr[:, set_idx, pos].any()

    overlap = np.empty((n_sets, n_sets), dtype=float_)
    for row in prange(n_sets):
        for col in range(n_sets):
            n_shared = (covered[row] & covered[col]).sum()
            n_combined = (covered[row] | covered[col]).sum()
            # NOTE(review): two sets without any covered position give a zero
            # denominator here - confirm the intended result for that case
            overlap[row, col] = n_shared / n_combined
    return overlap


@register_jitted(cache=True, tags={"can_parallel"})
def range_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Count the positions shared by each pair of ranges.

    Every `(split, set)` pair of `mask_arr` is one range. Ranges are flattened in
    row-major order, that is, the range of split `i` and set `j` gets the flat index
    `i * n_sets + j`. Element `[k, l]` of the returned square matrix is the number of
    positions marked in both ranges."""
    n_sets = mask_arr.shape[1]
    n_ranges = mask_arr.shape[0] * n_sets

    overlap = np.empty((n_ranges, n_ranges), dtype=int_)
    for row in prange(n_ranges):
        row_mask = mask_arr[row // n_sets, row % n_sets]
        for col in range(n_ranges):
            col_mask = mask_arr[col // n_sets, col % n_sets]
            overlap[row, col] = (row_mask & col_mask).sum()
    return overlap


@register_jitted(cache=True, tags={"can_parallel"})
def norm_range_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
    """Compute the overlap of each pair of ranges relative to their combined coverage.

    Works like `range_overlap_matrix_nb`, but element `[k, l]` is the number of
    positions marked in both ranges divided by the number marked in either of them."""
    n_sets = mask_arr.shape[1]
    n_ranges = mask_arr.shape[0] * n_sets

    overlap = np.empty((n_ranges, n_ranges), dtype=float_)
    for row in prange(n_ranges):
        row_mask = mask_arr[row // n_sets, row % n_sets]
        for col in range(n_ranges):
            col_mask = mask_arr[col // n_sets, col % n_sets]
            n_shared = (row_mask & col_mask).sum()
            n_combined = (row_mask | col_mask).sum()
            # NOTE(review): two empty ranges give a zero denominator here -
            # confirm the intended result for that case
            overlap[row, col] = n_shared / n_combined
    return overlap


@register_jitted(cache=True)
def split_range_by_gap_nb(range_: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """Break a range into runs of consecutive values.

    A new run begins wherever two neighboring elements do not differ by exactly one.
    Returns two arrays of equal length holding, for each run, the start position
    (inclusive) and the stop position (exclusive) within `range_`.

    Raises a `ValueError` if `range_` is empty."""
    n_values = len(range_)
    if n_values == 0:
        raise ValueError("Range is empty")

    # There cannot be more runs than elements
    starts = np.empty(n_values, dtype=int_)
    stops = np.empty(n_values, dtype=int_)

    run = 0
    starts[run] = 0
    for pos in range(1, n_values):
        if range_[pos] - range_[pos - 1] != 1:
            # Close the current run and open the next one at this position
            stops[run] = pos
            run += 1
            starts[run] = pos
    stops[run] = n_values

    n_runs = run + 1
    return starts[:n_runs], stops[:n_runs]
class BasePurgedCV(Base):
    """Base class for purged cross-validation of time series.

    Each sample carries two timestamps: a prediction time, at which its features are used
    to predict the response, and an evaluation time, at which the response becomes known and
    the error can be computed.

    In contrast to standard sklearn cross-validation, the samples X, the response y,
    `pred_times`, and `eval_times` must all be pandas DataFrames/Series sharing the same index.
    The samples are also assumed to be ordered by prediction time."""

    def __init__(self, n_folds: int = 10, purge_td: tp.TimedeltaLike = 0) -> None:
        self._n_folds = n_folds
        self._purge_td = dt.to_timedelta(purge_td)

        # Filled in by `BasePurgedCV.split`
        self._pred_times = None
        self._eval_times = None
        self._indices = None

    @property
    def n_folds(self) -> int:
        """Number of folds."""
        return self._n_folds

    @property
    def purge_td(self) -> pd.Timedelta:
        """Purge period."""
        return self._purge_td

    @property
    def pred_times(self) -> tp.Optional[pd.Series]:
        """Times at which predictions are made.

        None until `BasePurgedCV.split` has been run."""
        return self._pred_times

    @property
    def eval_times(self) -> tp.Optional[pd.Series]:
        """Times at which the response becomes available and the error can be computed.

        None until `BasePurgedCV.split` has been run."""
        return self._eval_times

    @property
    def indices(self) -> tp.Optional[tp.Array1d]:
        """Position indices of the samples.

        None until `BasePurgedCV.split` has been run."""
        return self._indices

    def purge(
        self,
        train_indices: tp.Array1d,
        test_fold_start: int,
        test_fold_end: int,
    ) -> tp.Array1d:
        """Purge part of the train set.

        Keeps two groups of train samples and returns them concatenated:

        * samples whose evaluation time, shifted by `BasePurgedCV.purge_td`, lies strictly
            before the prediction time of the sample at position `test_fold_start`
        * samples located at position `test_fold_end` or later

        All other train samples are removed."""
        test_start_time = self.pred_times.iloc[test_fold_start]
        shifted_eval_times = self.eval_times + self.purge_td

        known_before_test = self.indices[shifted_eval_times < test_start_time]
        located_after_test = self.indices[test_fold_end:]

        kept_before = np.intersect1d(train_indices, known_before_test)
        kept_after = np.intersect1d(train_indices, located_after_test)
        return np.concatenate((kept_before, kept_after))

    @abstractmethod
    def split(
        self,
        X: tp.SeriesFrame,
        y: tp.Optional[tp.Series] = None,
        pred_times: tp.Union[None, tp.Index, tp.Series] = None,
        eval_times: tp.Union[None, tp.Index, tp.Series] = None,
    ):
        """Yield the indices of the train and test sets.

        The base implementation only validates the arguments and stores the prediction times,
        the evaluation times, and the position indices on the instance. Times that are not
        provided default to the index of `X`."""

        def _resolve_times(times: tp.Union[None, tp.Index, tp.Series], arg_name: str) -> pd.Series:
            # Fall back to the index of X and make sure the result is a Series aligned with X
            if times is None:
                times = X.index
            if isinstance(times, pd.Index):
                return pd.Series(times, index=X.index)
            checks.assert_instance_of(times, pd.Series, arg_name=arg_name)
            checks.assert_index_equal(X.index, times.index, check_names=False)
            return times

        checks.assert_instance_of(X, (pd.Series, pd.DataFrame), arg_name="X")
        if y is not None:
            checks.assert_instance_of(y, pd.Series, arg_name="y")
        pred_times = _resolve_times(pred_times, "pred_times")
        eval_times = _resolve_times(eval_times, "eval_times")

        self._pred_times = pred_times
        self._eval_times = eval_times
        self._indices = np.arange(X.shape[0])
class PurgedWalkForwardCV(BasePurgedCV):
    """Purged walk-forward cross-validation.

    The samples are decomposed into `n_folds` folds containing equal numbers of samples, without
    shuffling. In each cross validation round, `n_test_folds` contiguous folds are used as the
    test set, while the train set consists in between `min_train_folds` and `max_train_folds`
    immediately preceding folds.

    Each sample should be tagged with a prediction time and an evaluation time. The split is such
    that the intervals [`pred_times`, `eval_times`] associated to samples in the train and test set
    do not overlap. (The overlapping samples are dropped.)

    With `split_by_time=True`, it is also possible to split the samples in folds spanning equal
    time intervals (using the prediction time as a time tag), instead of folds containing equal
    numbers of samples.

    Raises a `ValueError` if `n_test_folds`, `min_train_folds`, or `max_train_folds` is smaller
    than 1 or exceeds the upper bound checked in the constructor."""

    def __init__(
        self,
        n_folds: int = 10,
        n_test_folds: int = 1,
        min_train_folds: int = 2,
        max_train_folds: tp.Optional[int] = None,
        split_by_time: bool = False,
        purge_td: tp.TimedeltaLike = 0,
    ) -> None:
        BasePurgedCV.__init__(self, n_folds=n_folds, purge_td=purge_td)

        # Fold counts below 1 lead to empty test/train sets or out-of-range fold lookups,
        # so the lower bound promised by the messages is enforced as well.
        # NOTE(review): the upper bounds of n_test_folds and min_train_folds are exclusive,
        # which is stricter than the messages suggest - kept as is for backward compatibility
        if n_test_folds < 1 or n_test_folds >= self.n_folds - 1:
            raise ValueError("n_test_folds must be between 1 and n_folds - 1")
        self._n_test_folds = n_test_folds

        n_non_test_folds = self.n_folds - self.n_test_folds
        if min_train_folds < 1 or min_train_folds >= n_non_test_folds:
            raise ValueError("min_train_folds must be between 1 and n_folds - n_test_folds")
        self._min_train_folds = min_train_folds

        if max_train_folds is None:
            max_train_folds = n_non_test_folds
        if max_train_folds < 1 or max_train_folds > n_non_test_folds:
            raise ValueError("max_train_folds must be between 1 and n_folds - n_test_folds")
        self._max_train_folds = max_train_folds

        self._split_by_time = split_by_time
        self._fold_bounds = []

    @property
    def n_test_folds(self) -> int:
        """Number of folds used in the test set."""
        return self._n_test_folds

    @property
    def min_train_folds(self) -> int:
        """Minimal number of folds to be used in the train set."""
        return self._min_train_folds

    @property
    def max_train_folds(self) -> int:
        """Maximal number of folds to be used in the train set."""
        return self._max_train_folds

    @property
    def split_by_time(self) -> bool:
        """Whether the folds span identical time intervals.

        Otherwise, the folds contain an (approximately) equal number of samples."""
        return self._split_by_time

    @property
    def fold_bounds(self) -> tp.List[int]:
        """Fold boundaries.

        Empty until `PurgedWalkForwardCV.split` has been iterated."""
        return self._fold_bounds

    def compute_fold_bounds(self) -> tp.List[int]:
        """Compute the left boundary (position index) of each fold.

        With `PurgedWalkForwardCV.split_by_time`, the time span of the prediction times is cut
        into `n_folds` equal intervals and each boundary is looked up with `searchsorted`.
        Otherwise, the position indices are cut into `n_folds` chunks of (approximately)
        equal size and the first position of each chunk is taken."""
        if self.split_by_time:
            full_time_span = self.pred_times.max() - self.pred_times.min()
            fold_time_span = full_time_span / self.n_folds
            first_pred_time = self.pred_times.iloc[0]
            fold_bounds_times = [first_pred_time + fold_time_span * n for n in range(self.n_folds)]
            return self.pred_times.searchsorted(fold_bounds_times)
        return [fold[0] for fold in np.array_split(self.indices, self.n_folds)]

    def compute_train_set(self, fold_bound: int, count_folds: int) -> tp.Array1d:
        """Compute the position indices of the samples in the train set.

        The train set ends right before `fold_bound` and reaches back at most
        `PurgedWalkForwardCV.max_train_folds` folds. It is purged before being returned."""
        if count_folds > self.max_train_folds:
            start_train = self.fold_bounds[count_folds - self.max_train_folds]
        else:
            start_train = 0
        train_indices = np.arange(start_train, fold_bound)
        return self.purge(train_indices, fold_bound, self.indices[-1])

    def compute_test_set(self, fold_bound: int, count_folds: int) -> tp.Array1d:
        """Compute the position indices of the samples in the test set.

        The test set starts at `fold_bound` and spans `PurgedWalkForwardCV.n_test_folds` folds,
        or runs up to the last sample if no later fold boundary is available."""
        if self.n_folds - count_folds > self.n_test_folds:
            end_test = self.fold_bounds[count_folds + self.n_test_folds]
        else:
            end_test = self.indices[-1] + 1
        return np.arange(fold_bound, end_test)

    def split(
        self,
        X: tp.SeriesFrame,
        y: tp.Optional[tp.Series] = None,
        pred_times: tp.Union[None, tp.Index, tp.Series] = None,
        eval_times: tp.Union[None, tp.Index, tp.Series] = None,
    ) -> tp.Iterable[tp.Tuple[tp.Array1d, tp.Array1d]]:
        """Yield the position indices of the train and test sets of each round.

        The first `PurgedWalkForwardCV.min_train_folds` folds are never used as a test set.
        Iteration stops as soon as fewer than `PurgedWalkForwardCV.n_test_folds` folds remain."""
        BasePurgedCV.split(self, X, y, pred_times=pred_times, eval_times=eval_times)
        self._fold_bounds = self.compute_fold_bounds()

        for count_folds, fold_bound in enumerate(self.fold_bounds):
            if count_folds < self.min_train_folds:
                # Not enough preceding folds to train on yet
                continue
            if self.n_folds - count_folds < self.n_test_folds:
                break
            test_indices = self.compute_test_set(fold_bound, count_folds)
            train_indices = self.compute_train_set(fold_bound, count_folds)
            yield train_indices, test_indices
class PurgedKFoldCV(BasePurgedCV):
    """Purged and embargoed combinatorial cross-validation.

    The samples are decomposed into `n_folds` folds containing equal numbers of samples, without
    shuffling. In each cross validation round, `n_test_folds` folds are used as the test set,
    while the other folds are used as the train set. There are as many rounds as there are
    combinations of `n_test_folds` folds among the `n_folds` folds.

    Each sample should be tagged with a prediction time and an evaluation time. The split is such
    that the intervals [`pred_times`, `eval_times`] associated to samples in the train and test set
    do not overlap. (The overlapping samples are dropped.) In addition, an "embargo" period is
    defined, giving the minimal time between an evaluation time in the test set and a prediction
    time in the training set. This is to avoid, in the presence of temporal correlation,
    a contamination of the test set by the train set.

    Raises a `ValueError` if `n_test_folds` is not between 1 and `n_folds - 1`."""

    def __init__(
        self,
        n_folds: int = 10,
        n_test_folds: int = 2,
        purge_td: tp.TimedeltaLike = 0,
        embargo_td: tp.TimedeltaLike = 0,
    ) -> None:
        BasePurgedCV.__init__(self, n_folds=n_folds, purge_td=purge_td)

        # Without at least one test fold, every round would have an empty test set
        if n_test_folds < 1 or n_test_folds > self.n_folds - 1:
            raise ValueError("n_test_folds must be between 1 and n_folds - 1")
        self._n_test_folds = n_test_folds
        self._embargo_td = dt.to_timedelta(embargo_td)

    @property
    def n_test_folds(self) -> int:
        """Number of folds used in the test set."""
        return self._n_test_folds

    @property
    def embargo_td(self) -> pd.Timedelta:
        """Embargo period."""
        return self._embargo_td

    def embargo(
        self,
        train_indices: tp.Array1d,
        test_indices: tp.Array1d,
        test_fold_end: int,
    ) -> tp.Array1d:
        """Apply the embargo procedure to part of the train set.

        This amounts to dropping the train set samples whose prediction time occurs within
        `PurgedKFoldCV.embargo_td` of the test set sample evaluation times. This method applies
        the embargo only to the part of the training set immediately following the end of
        the test set determined by `test_fold_end`.

        Relies on the prediction times being sorted: the number of samples predicted no later
        than the end of the embargo is used as the first allowed position after the test set."""
        test_up_to_end = test_indices[test_indices <= test_fold_end]
        last_test_eval_time = self.eval_times.iloc[test_up_to_end].max()
        embargo_end = last_test_eval_time + self.embargo_td
        min_train_index = len(self.pred_times[self.pred_times <= embargo_end])

        # The slice below is simply empty when the embargo reaches the last sample, in which
        # case everything after the test set is dropped. Skipping the embargo in that case
        # would keep exactly the samples that are closest to the test set.
        allowed_indices = np.concatenate((self.indices[:test_fold_end], self.indices[min_train_index:]))
        return np.intersect1d(train_indices, allowed_indices)

    def compute_train_set(
        self,
        test_fold_bounds: tp.List[tp.Tuple[int, int]],
        test_indices: tp.Array1d,
    ) -> tp.Array1d:
        """Compute the position indices of the samples in the train set.

        Starts with all samples that are not in the test set, then purges and embargoes
        them around each of the (merged) test folds."""
        train_indices = np.setdiff1d(self.indices, test_indices)
        for test_fold_start, test_fold_end in test_fold_bounds:
            train_indices = self.purge(train_indices, test_fold_start, test_fold_end)
            train_indices = self.embargo(train_indices, test_indices, test_fold_end)
        return train_indices

    def compute_test_set(
        self,
        fold_bound_list: tp.List[tp.Tuple[int, int]],
    ) -> tp.Tuple[tp.List[tp.Tuple[int, int]], tp.Array1d]:
        """Compute the position indices of the samples in the test set.

        Returns the bounds of the test folds, with directly adjacent folds merged into
        one, along with the position indices of all test samples."""
        test_indices = np.empty(0, dtype=int)
        test_fold_bounds = []
        for fold_start, fold_end in fold_bound_list:
            if test_fold_bounds and fold_start == test_fold_bounds[-1][-1]:
                # The fold continues the previous one: extend its right bound
                test_fold_bounds[-1] = (test_fold_bounds[-1][0], fold_end)
            else:
                test_fold_bounds.append((fold_start, fold_end))
            test_indices = np.union1d(test_indices, self.indices[fold_start:fold_end]).astype(int)
        return test_fold_bounds, test_indices

    def split(
        self,
        X: tp.SeriesFrame,
        y: tp.Optional[tp.Series] = None,
        pred_times: tp.Union[None, tp.Index, tp.Series] = None,
        eval_times: tp.Union[None, tp.Index, tp.Series] = None,
    ) -> tp.Iterable[tp.Tuple[tp.Array1d, tp.Array1d]]:
        """Yield the position indices of the train and test sets of each round.

        One round is produced for each combination of `PurgedKFoldCV.n_test_folds` folds,
        in the reverse of the order generated by `itertools.combinations`."""
        BasePurgedCV.split(self, X, y, pred_times=pred_times, eval_times=eval_times)

        fold_bounds = [(fold[0], fold[-1] + 1) for fold in np.array_split(self.indices, self.n_folds)]
        selected_fold_bounds = list(combinations(fold_bounds, self.n_test_folds))
        for fold_bound_list in reversed(selected_fold_bounds):
            test_fold_bounds, test_indices = self.compute_test_set(fold_bound_list)
            train_indices = self.compute_train_set(test_fold_bounds, test_indices)
            yield train_indices, test_indices
class SplitterCV(BaseCrossValidator, Base):
    """Scikit-learn compatible cross-validator based on `vectorbtpro.generic.splitting.base.Splitter`.

    Usage:
        * Replicate `TimeSeriesSplit` from scikit-learn:

        ```pycon
        >>> from vectorbtpro import *

        >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
        >>> y = np.array([1, 2, 3, 4])

        >>> cv = vbt.SplitterCV(
        ...     "from_expanding",
        ...     min_length=2,
        ...     offset=1,
        ...     split=-1
        ... )
        >>> for i, (train_indices, test_indices) in enumerate(cv.split(X)):
        ...     print("Split %d:" % i)
        ...     X_train, X_test = X[train_indices], X[test_indices]
        ...     print("  X:", X_train.tolist(), X_test.tolist())
        ...     y_train, y_test = y[train_indices], y[test_indices]
        ...     print("  y:", y_train.tolist(), y_test.tolist())
        Split 0:
          X: [[1, 2]] [[3, 4]]
          y: [1] [2]
        Split 1:
          X: [[1, 2], [3, 4]] [[5, 6]]
          y: [1, 2] [3]
        Split 2:
          X: [[1, 2], [3, 4], [5, 6]] [[7, 8]]
          y: [1, 2, 3] [4]
        ```
    """

    def __init__(
        self,
        splitter: tp.Union[None, str, Splitter, tp.Callable] = None,
        *,
        splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
        split_group_by: tp.AnyGroupByLike = None,
        set_group_by: tp.AnyGroupByLike = None,
        template_context: tp.KwargsLike = None,
        **splitter_kwargs,
    ) -> None:
        if splitter_cls is None:
            splitter_cls = Splitter
        if splitter is None:
            splitter = splitter_cls.guess_method(**splitter_kwargs)
        self._splitter = splitter
        self._splitter_kwargs = splitter_kwargs
        self._splitter_cls = splitter_cls
        self._split_group_by = split_group_by
        self._set_group_by = set_group_by
        self._template_context = template_context

    @property
    def splitter(self) -> tp.Union[str, Splitter, tp.Callable]:
        """Splitter.

        Either as a `vectorbtpro.generic.splitting.base.Splitter` instance, a factory method name,
        or the factory method itself.

        If None, will be determined automatically based on `SplitterCV.splitter_kwargs`."""
        return self._splitter

    @property
    def splitter_cls(self) -> tp.Type[Splitter]:
        """Splitter class.

        Defaults to `vectorbtpro.generic.splitting.base.Splitter`."""
        return self._splitter_cls

    @property
    def splitter_kwargs(self) -> tp.KwargsLike:
        """Keyword arguments passed to the factory method."""
        return self._splitter_kwargs

    @property
    def split_group_by(self) -> tp.AnyGroupByLike:
        """Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.

        Not passed to the factory method."""
        return self._split_group_by

    @property
    def set_group_by(self) -> tp.AnyGroupByLike:
        """Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.

        Not passed to the factory method."""
        return self._set_group_by

    @property
    def template_context(self) -> tp.KwargsLike:
        """Mapping used to substitute templates in ranges.

        Passed to the factory method."""
        return self._template_context

    def get_splitter(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> Splitter:
        """Get splitter of type `vectorbtpro.generic.splitting.base.Splitter`.

        If `SplitterCV.splitter` is already a splitter instance, it is used as is. Otherwise,
        the factory method is called on the index of `X`, or on a simple range index of
        the same length if no index can be extracted from `X`.

        Raises a `ValueError` if the splitter does not have exactly two sets (train and test)."""
        X, y, groups = indexable(X, y, groups)
        if isinstance(self.splitter, Splitter):
            # A ready-made instance is documented as valid input and must not be
            # called like a factory method
            splitter = self.splitter
        else:
            try:
                index = self.splitter_cls.get_obj_index(X)
            except ValueError:
                index = pd.RangeIndex(stop=len(X))
            if isinstance(self.splitter, str):
                factory = getattr(self.splitter_cls, self.splitter)
            else:
                factory = self.splitter
            splitter = factory(
                index,
                template_context=self.template_context,
                **self.splitter_kwargs,
            )
        if splitter.get_n_sets(set_group_by=self.set_group_by) != 2:
            raise ValueError("Number of sets in the splitter must be 2: train and test")
        return splitter

    def _iter_masks(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
        """Generates boolean masks corresponding to train and test sets."""
        splitter = self.get_splitter(X=X, y=y, groups=groups)
        for mask_arr in splitter.get_iter_split_mask_arrs(
            split_group_by=self.split_group_by,
            set_group_by=self.set_group_by,
            template_context=self.template_context,
        ):
            # The first set is the train set, the second one the test set
            yield mask_arr[0], mask_arr[1]

    def _iter_train_masks(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Array1d]:
        """Generates boolean masks corresponding to train sets."""
        for train_mask_arr, _ in self._iter_masks(X=X, y=y, groups=groups):
            yield train_mask_arr

    def _iter_test_masks(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Array1d]:
        """Generates boolean masks corresponding to test sets."""
        for _, test_mask_arr in self._iter_masks(X=X, y=y, groups=groups):
            yield test_mask_arr

    def _iter_indices(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
        """Generates integer indices corresponding to train and test sets."""
        for train_mask_arr, test_mask_arr in self._iter_masks(X=X, y=y, groups=groups):
            yield np.flatnonzero(train_mask_arr), np.flatnonzero(test_mask_arr)

    def _iter_train_indices(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Array1d]:
        """Generates integer indices corresponding to train sets."""
        for train_indices, _ in self._iter_indices(X=X, y=y, groups=groups):
            yield train_indices

    def _iter_test_indices(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Array1d]:
        """Generates integer indices corresponding to test sets."""
        for _, test_indices in self._iter_indices(X=X, y=y, groups=groups):
            yield test_indices

    def get_n_splits(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> int:
        """Returns the number of splitting iterations in the cross-validator."""
        splitter = self.get_splitter(X=X, y=y, groups=groups)
        return splitter.get_n_splits(split_group_by=self.split_group_by)

    def split(
        self,
        X: tp.Any = None,
        y: tp.Any = None,
        groups: tp.Any = None,
    ) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
        """Generate indices to split data into training and test set."""
        return self._iter_indices(X=X, y=y, groups=groups)
from typing import TYPE_CHECKING

# `TYPE_CHECKING` is False at runtime, so the star-imports below are only seen by
# static type checkers and IDEs and do not import any submodule here.
# NOTE(review): presumably the package loads these submodules through its own
# import machinery - confirm
if TYPE_CHECKING:
    from vectorbtpro.generic.nb import *
    from vectorbtpro.generic.splitting import *
    from vectorbtpro.generic.accessors import *
    from vectorbtpro.generic.analyzable import *
    from vectorbtpro.generic.decorators import *
    from vectorbtpro.generic.drawdowns import *
    from vectorbtpro.generic.plots_builder import *
    from vectorbtpro.generic.plotting import *
    from vectorbtpro.generic.price_records import *
    from vectorbtpro.generic.ranges import *
    from vectorbtpro.generic.sim_range import *
    from vectorbtpro.generic.stats_builder import *

# NOTE(review): looks like the submodules listed here are left out when the
# package-level `__all__` is assembled - confirm against the code that reads it
__exclude_from__all__ = [
    "enums",
]

# NOTE(review): appears to map a submodule name to the optional package that must be
# installed for the submodule to be imported (`plotting` needs `plotly`) - confirm
__import_if_installed__ = dict()
__import_if_installed__["plotting"] = "plotly"
Run for the examples below: ```pycon >>> df = pd.DataFrame({ ... 'a': [1, 2, 3, 4, 5], ... 'b': [5, 4, 3, 2, 1], ... 'c': [1, 2, 3, 2, 1] ... }, index=pd.Index(pd.date_range("2020", periods=5))) >>> df a b c 2020-01-01 1 5 1 2020-01-02 2 4 2 2020-01-03 3 3 3 2020-01-04 4 2 2 2020-01-05 5 1 1 >>> sr = pd.Series(np.arange(10), index=pd.date_range("2020", periods=10)) >>> sr 2020-01-01 0 2020-01-02 1 2020-01-03 2 2020-01-04 3 2020-01-05 4 2020-01-06 5 2020-01-07 6 2020-01-08 7 2020-01-09 8 2020-01-10 9 dtype: int64 ``` ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `GenericAccessor.metrics`. ```pycon >>> df2 = pd.DataFrame({ ... 'a': [np.nan, 2, 3], ... 'b': [4, np.nan, 5], ... 'c': [6, 7, np.nan] ... }, index=['x', 'y', 'z']) >>> df2.vbt(freq='d').stats(column='a') Start x End z Period 3 days 00:00:00 Count 2 Mean 2.5 Std 0.707107 Min 2.0 Median 2.5 Max 3.0 Min Index y Max Index z Name: a, dtype: object ``` ### Mapping Mapping can be set both in `GenericAccessor` (preferred) and `GenericAccessor.stats`: ```pycon >>> mapping = {x: 'test_' + str(x) for x in pd.unique(df2.values.flatten())} >>> df2.vbt(freq='d', mapping=mapping).stats(column='a') Start x End z Period 3 days 00:00:00 Count 2 Value Counts: test_2.0 1 Value Counts: test_3.0 1 Value Counts: test_4.0 0 Value Counts: test_5.0 0 Value Counts: test_6.0 0 Value Counts: test_7.0 0 Value Counts: test_nan 1 Name: a, dtype: object >>> df2.vbt(freq='d').stats(column='a', settings=dict(mapping=mapping)) UserWarning: Changing the mapping will create a copy of this object. Consider setting it upon object creation to re-use existing cache. 
Start x End z Period 3 days 00:00:00 Count 2 Value Counts: test_2.0 1 Value Counts: test_3.0 1 Value Counts: test_4.0 0 Value Counts: test_5.0 0 Value Counts: test_6.0 0 Value Counts: test_7.0 0 Value Counts: test_nan 1 Name: a, dtype: object ``` Selecting a column before calling `stats` will consider uniques from this column only: ```pycon >>> df2['a'].vbt(freq='d', mapping=mapping).stats() Start x End z Period 3 days 00:00:00 Count 2 Value Counts: test_2.0 1 Value Counts: test_3.0 1 Value Counts: test_nan 1 Name: a, dtype: object ``` To include all keys from `mapping`, pass `incl_all_keys=True`: ```pycon >>> df2['a'].vbt(freq='d', mapping=mapping).stats(settings=dict(incl_all_keys=True)) Start x End z Period 3 days 00:00:00 Count 2 Value Counts: test_2.0 1 Value Counts: test_3.0 1 Value Counts: test_4.0 0 Value Counts: test_5.0 0 Value Counts: test_6.0 0 Value Counts: test_7.0 0 Value Counts: test_nan 1 Name: a, dtype: object ``` `GenericAccessor.stats` also supports (re-)grouping: ```pycon >>> df2.vbt(freq='d').stats(column=0, group_by=[0, 0, 1]) Start x End z Period 3 days 00:00:00 Count 4 Mean 3.5 Std 1.290994 Min 2.0 Median 3.5 Max 5.0 Min Index y Max Index z Name: 0, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `GenericAccessor.subplots`. 
`GenericAccessor` class has a single subplot based on `GenericAccessor.plot`: ```pycon >>> df2.vbt.plots().show() ``` ![](/assets/images/api/generic_plots.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/generic_plots.dark.svg#only-dark){: .iimg loading=lazy } """ from functools import partial import numpy as np import pandas as pd from pandas.core.resample import Resampler as PandasResampler from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro._settings import settings from vectorbtpro.base import indexes, reshaping from vectorbtpro.base.accessors import BaseAccessor, BaseDFAccessor, BaseSRAccessor from vectorbtpro.base.indexes import repeat_index from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping from vectorbtpro.generic import nb from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.generic.decorators import attach_nb_methods, attach_transform_methods from vectorbtpro.generic.drawdowns import Drawdowns from vectorbtpro.generic.enums import WType, InterpMode, RescaleMode, ErrorType, DistanceMeasure from vectorbtpro.generic.plots_builder import PlotsBuilderMixin from vectorbtpro.generic.ranges import Ranges, PatternRanges from vectorbtpro.generic.stats_builder import StatsBuilderMixin from vectorbtpro.records.mapped_array import MappedArray from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks, chunking as ch, datetime_ as dt from vectorbtpro.utils.colors import adjust_opacity, map_value_to_cmap from vectorbtpro.utils.config import merge_dicts, resolve_dict, Config, ReadonlyConfig, HybridConfig from vectorbtpro.utils.decorators import hybrid_method, hybrid_property from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.mapping import apply_mapping, to_value_mapping from vectorbtpro.utils.template import 
# NaN-aware reducers: bound to bottleneck's implementations when the package can be
# imported, otherwise to the NumPy functions of the same name.
try:
    import bottleneck as bn
except ImportError:
    # slower numpy
    (
        nanmean,
        nanstd,
        nansum,
        nanmax,
        nanmin,
        nanmedian,
        nanargmax,
        nanargmin,
    ) = (
        np.nanmean,
        np.nanstd,
        np.nansum,
        np.nanmax,
        np.nanmin,
        np.nanmedian,
        np.nanargmax,
        np.nanargmin,
    )
else:
    (
        nanmean,
        nanstd,
        nansum,
        nanmax,
        nanmin,
        nanmedian,
        nanargmax,
        nanargmin,
    ) = (
        bn.nanmean,
        bn.nanstd,
        bn.nansum,
        bn.nanmax,
        bn.nanmin,
        bn.nanmedian,
        bn.nanargmax,
        bn.nanargmin,
    )
def __init__(
    self,
    wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
    obj: tp.Optional[tp.ArrayLike] = None,
    mapping: tp.Optional[tp.MappingLike] = None,
    **kwargs,
) -> None:
    """Initialize the accessor and store `mapping`.

    `wrapper`, `obj`, `mapping`, and any extra keyword arguments are forwarded to
    `BaseAccessor.__init__`. The stats and plots builder mixins are then initialized
    without arguments, and `mapping` is kept in `self._mapping`."""
    BaseAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs)
    # Both builder mixins are initialized the same way, stats first
    for builder_mixin in (StatsBuilderMixin, PlotsBuilderMixin):
        builder_mixin.__init__(self)

    self._mapping = mapping
def ago(
    self,
    n: tp.Union[int, tp.FrequencyLike],
    fill_value: tp.Scalar = np.nan,
    get_indexer_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """For each value, get the value `n` periods ago.

    If `n` is an integer (as determined by `checks.is_int`), the call is delegated to
    `GenericAccessor.fshift` together with `fill_value` and `**kwargs`.

    Otherwise, `n` is converted to a timedelta and, for each index label, the position of
    the label `n` earlier is looked up with `pd.Index.get_indexer` (`get_indexer_kwargs`
    are passed to that call). Rows for which no label is found keep `fill_value`.
    `**kwargs` are not used in this case."""
    if checks.is_int(n):
        return self.fshift(n, fill_value=fill_value, **kwargs)

    indexer_kwargs = {} if get_indexer_kwargs is None else get_indexer_kwargs
    delta = dt.to_timedelta(n)
    index = self.wrapper.index
    source_positions = index.get_indexer(index - delta, **indexer_kwargs)
    new_obj = self.wrapper.fill(fill_value=fill_value)
    # get_indexer marks labels that could not be located with -1
    found_mask = source_positions != -1
    # Assign raw values rather than a pandas object: source and target rows carry different
    # labels, so the assignment must be purely positional and never aligned by label
    new_obj.iloc[np.flatnonzero(found_mask)] = self.obj.iloc[source_positions[found_mask]].values
    return new_obj
def rolling_idxmin(
    self,
    window: tp.Optional[int],
    minp: tp.Optional[int] = None,
    local: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.rolling_argmin_nb`.

    A `window` of None is replaced by `self.wrapper.shape[0]`. Unless `local` is True,
    `dict(to_index=True)` is merged with `wrap_kwargs` before wrapping."""
    n_periods = self.wrapper.shape[0] if window is None else window
    argmin_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.rolling_argmin_nb, jitted),
        chunked,
    )
    argmin_out = argmin_func(self.to_2d_array(), n_periods, minp=minp, local=local)
    if local:
        final_wrap_kwargs = wrap_kwargs
    else:
        final_wrap_kwargs = merge_dicts(dict(to_index=True), wrap_kwargs)
    return self.wrapper.wrap(argmin_out, group_by=False, **resolve_dict(final_wrap_kwargs))
def rolling_std(
    self,
    window: tp.Optional[int],
    minp: tp.Optional[int] = None,
    ddof: int = 1,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.rolling_std_nb`.

    Passing None as `window` substitutes `self.wrapper.shape[0]`."""
    n_periods = self.wrapper.shape[0] if window is None else window
    std_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.rolling_std_nb, jitted),
        chunked,
    )
    std_out = std_func(self.to_2d_array(), n_periods, minp=minp, ddof=ddof)
    return self.wrapper.wrap(std_out, group_by=False, **resolve_dict(wrap_kwargs))
def ewm_mean(
    self,
    span: int,
    minp: tp.Optional[int] = 0,
    adjust: bool = True,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.ewm_mean_nb`.

    The function is resolved through the jitting and chunking registries, called on the
    2-dim array of this object, and its output is wrapped with `group_by=False`."""
    ewm_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.ewm_mean_nb, jitted),
        chunked,
    )
    ewm_out = ewm_func(self.to_2d_array(), span, minp=minp, adjust=adjust)
    return self.wrapper.wrap(ewm_out, group_by=False, **resolve_dict(wrap_kwargs))
def ma(
    self,
    window: int,
    wtype: tp.Union[int, str] = "simple",
    minp: tp.Optional[int] = 0,
    adjust: bool = True,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.ma_nb`.

    A string `wtype` is first mapped to a field of `vectorbtpro.generic.enums.WType`
    using `map_enum_fields`; any other value is passed through unchanged."""
    wtype_value = map_enum_fields(wtype, WType) if isinstance(wtype, str) else wtype
    ma_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.ma_nb, jitted),
        chunked,
    )
    ma_out = ma_func(self.to_2d_array(), window, wtype=wtype_value, minp=minp, adjust=adjust)
    return self.wrapper.wrap(ma_out, group_by=False, **resolve_dict(wrap_kwargs))
def rolling_cov(
    self,
    other: tp.SeriesFrame,
    window: tp.Optional[int],
    minp: tp.Optional[int] = None,
    ddof: int = 1,
    broadcast_kwargs: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.rolling_cov_nb`.

    This object and `other` are broadcast against each other with
    `vectorbtpro.base.reshaping.broadcast` (taking `broadcast_kwargs`). A `window` of None
    is replaced by the first dimension of the broadcast object. The output is wrapped with
    a wrapper built from the broadcast version of this object."""
    left, right = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
    n_periods = left.shape[0] if window is None else window
    cov_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.rolling_cov_nb, jitted),
        chunked,
    )
    left_2d = reshaping.to_2d_array(left)
    right_2d = reshaping.to_2d_array(right)
    cov_out = cov_func(left_2d, right_2d, n_periods, minp=minp, ddof=ddof)
    out_wrapper = ArrayWrapper.from_obj(left)
    return out_wrapper.wrap(cov_out, group_by=False, **resolve_dict(wrap_kwargs))
def rolling_ols(
    self,
    other: tp.SeriesFrame,
    window: tp.Optional[int],
    minp: tp.Optional[int] = None,
    broadcast_kwargs: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]:
    """See `vectorbtpro.generic.nb.rolling.rolling_ols_nb`.

    This object and `other` are broadcast against each other with
    `vectorbtpro.base.reshaping.broadcast` (taking `broadcast_kwargs`). A `window` of None
    is replaced by the first dimension of the broadcast object.

    Returns two arrays: slope and intercept. Both are wrapped with the same wrapper, built
    from the broadcast version of this object, and the same `wrap_kwargs`."""
    self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
    if window is None:
        window = self_obj.shape[0]
    func = jit_reg.resolve_option(nb.rolling_ols_nb, jitted)
    func = ch_reg.resolve_option(func, chunked)
    slope_out, intercept_out = func(
        reshaping.to_2d_array(self_obj),
        reshaping.to_2d_array(other_obj),
        window,
        minp=minp,
    )
    # Both outputs share the same layout: build the wrapper and resolve the keyword
    # arguments once instead of once per output
    out_wrapper = ArrayWrapper.from_obj(self_obj)
    wrap_kwargs = resolve_dict(wrap_kwargs)
    return (
        out_wrapper.wrap(slope_out, group_by=False, **wrap_kwargs),
        out_wrapper.wrap(intercept_out, group_by=False, **wrap_kwargs),
    )
def rolling_pattern_similarity(
    self,
    pattern: tp.ArrayLike,
    window: tp.Optional[int] = None,
    max_window: tp.Optional[int] = None,
    row_select_prob: float = 1.0,
    window_select_prob: float = 1.0,
    interp_mode: tp.Union[int, str] = "mixed",
    rescale_mode: tp.Union[int, str] = "minmax",
    vmin: float = np.nan,
    vmax: float = np.nan,
    pmin: float = np.nan,
    pmax: float = np.nan,
    invert: bool = False,
    error_type: tp.Union[int, str] = "absolute",
    distance_measure: tp.Union[int, str] = "mae",
    max_error: tp.ArrayLike = np.nan,
    max_error_interp_mode: tp.Union[None, int, str] = None,
    max_error_as_maxdist: bool = False,
    max_error_strict: bool = False,
    min_pct_change: float = np.nan,
    max_pct_change: float = np.nan,
    min_similarity: float = np.nan,
    minp: tp.Optional[int] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.rolling.rolling_pattern_similarity_nb`.

    String values of `interp_mode`, `rescale_mode`, `error_type`, `distance_measure`, and
    `max_error_interp_mode` are mapped to fields of `InterpMode`, `RescaleMode`, `ErrorType`,
    `DistanceMeasure`, and `InterpMode` respectively. If `max_error_interp_mode` is None,
    the (already mapped) `interp_mode` is used in its place. `pattern` and `max_error` are
    converted with `vectorbtpro.base.reshaping.to_1d_array`."""
    interp_mode = map_enum_fields(interp_mode, InterpMode) if isinstance(interp_mode, str) else interp_mode
    rescale_mode = map_enum_fields(rescale_mode, RescaleMode) if isinstance(rescale_mode, str) else rescale_mode
    error_type = map_enum_fields(error_type, ErrorType) if isinstance(error_type, str) else error_type
    if isinstance(distance_measure, str):
        distance_measure = map_enum_fields(distance_measure, DistanceMeasure)
    # None is never a string, so no separate None check is needed before mapping
    if isinstance(max_error_interp_mode, str):
        max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode)
    if max_error_interp_mode is None:
        max_error_interp_mode = interp_mode

    similarity_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.rolling_pattern_similarity_nb, jitted),
        chunked,
    )
    arr_2d = self.to_2d_array()
    pattern_1d = reshaping.to_1d_array(pattern)
    max_error_1d = reshaping.to_1d_array(max_error)
    similarity_out = similarity_func(
        arr_2d,
        pattern_1d,
        window=window,
        max_window=max_window,
        row_select_prob=row_select_prob,
        window_select_prob=window_select_prob,
        interp_mode=interp_mode,
        rescale_mode=rescale_mode,
        vmin=vmin,
        vmax=vmax,
        pmin=pmin,
        pmax=pmax,
        invert=invert,
        error_type=error_type,
        distance_measure=distance_measure,
        max_error=max_error_1d,
        max_error_interp_mode=max_error_interp_mode,
        max_error_as_maxdist=max_error_as_maxdist,
        max_error_strict=max_error_strict,
        min_pct_change=min_pct_change,
        max_pct_change=max_pct_change,
        min_similarity=min_similarity,
        minp=minp,
    )
    return self.wrapper.wrap(similarity_out, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def apply_along_axis(
    cls_or_self,
    apply_func_nb: tp.Union[str, tp.ApplyFunc, tp.ApplyMetaFunc],
    *args,
    axis: int = 1,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Apply a function along an axis.

    On an instance, see `vectorbtpro.generic.nb.apply_reduce.apply_nb` for `axis=1` and
    `vectorbtpro.generic.nb.apply_reduce.row_apply_nb` for `axis=0`.

    On the class (meta version), see `vectorbtpro.generic.nb.apply_reduce.apply_meta_nb`
    for `axis=1` and `vectorbtpro.generic.nb.apply_reduce.row_apply_meta_nb` for `axis=0`.
    Here, `broadcast_named_args` are broadcast (to the shape of `wrapper` if it is provided)
    and templates in `args` are substituted using the broadcast arrays, `wrapper`, `axis`,
    and `template_context`. If there is nothing to broadcast, `wrapper` must be provided.

    If `apply_func_nb` is a string, the function `<apply_func_nb>_apply_nb` is taken from
    `vectorbtpro.generic.nb`.

    Usage:
        * Using regular function:

        ```pycon
        >>> power_nb = njit(lambda a: np.power(a, 2))

        >>> df.vbt.apply_along_axis(power_nb)
                     a   b  c
        2020-01-01   1  25  1
        2020-01-02   4  16  4
        2020-01-03   9   9  9
        2020-01-04  16   4  4
        2020-01-05  25   1  1
        ```

        * Using meta function:

        ```pycon
        >>> ratio_meta_nb = njit(lambda col, a, b: a[:, col] / b[:, col])

        >>> vbt.pd_acc.apply_along_axis(
        ...     ratio_meta_nb,
        ...     df.vbt.to_2d_array() - 1,
        ...     df.vbt.to_2d_array() + 1,
        ...     wrapper=df.vbt.wrapper
        ... )
                           a         b         c
        2020-01-01  0.000000  0.666667  0.000000
        2020-01-02  0.333333  0.600000  0.333333
        2020-01-03  0.500000  0.500000  0.500000
        2020-01-04  0.600000  0.333333  0.333333
        2020-01-05  0.666667  0.000000  0.000000
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> vbt.pd_acc.apply_along_axis(
        ...     ratio_meta_nb,
        ...     vbt.Rep('a'),
        ...     vbt.Rep('b'),
        ...     broadcast_named_args=dict(
        ...         a=pd.Series([1, 2, 3, 4, 5], index=df.index),
        ...         b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
        ...     )
        ... )
                      a    b         c
        2020-01-01  1.0  0.5  0.333333
        2020-01-02  2.0  1.0  0.666667
        2020-01-03  3.0  1.5  1.000000
        2020-01-04  4.0  2.0  1.333333
        2020-01-05  5.0  2.5  1.666667
        ```
    """
    checks.assert_in(axis, (0, 1))
    broadcast_named_args = {} if broadcast_named_args is None else broadcast_named_args
    broadcast_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    template_context = {} if template_context is None else template_context
    if isinstance(apply_func_nb, str):
        apply_func_nb = getattr(nb, apply_func_nb + "_apply_nb")

    if not isinstance(cls_or_self, type):
        # Called on an instance: apply directly to the accessor's own array
        raw_func = nb.row_apply_nb if axis == 0 else nb.apply_nb
        apply_func = ch_reg.resolve_option(jit_reg.resolve_option(raw_func, jitted), chunked)
        applied = apply_func(cls_or_self.to_2d_array(), apply_func_nb, *args)
        if wrapper is None:
            wrapper = cls_or_self.wrapper
        return wrapper.wrap(applied, group_by=False, **resolve_dict(wrap_kwargs))

    # Called on the class: meta version operating on a target shape
    if len(broadcast_named_args) == 0:
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
        if wrapper is None:
            broadcast_named_args, wrapper = reshaping.broadcast(
                broadcast_named_args,
                return_wrapper=True,
                **broadcast_kwargs,
            )
        else:
            broadcast_named_args = reshaping.broadcast(
                broadcast_named_args,
                to_shape=wrapper.shape_2d,
                **broadcast_kwargs,
            )
    template_context = merge_dicts(
        broadcast_named_args,
        dict(wrapper=wrapper, axis=axis),
        template_context,
    )
    args = substitute_templates(args, template_context, eval_id="args")
    raw_meta_func = nb.row_apply_meta_nb if axis == 0 else nb.apply_meta_nb
    meta_func = ch_reg.resolve_option(jit_reg.resolve_option(raw_meta_func, jitted), chunked)
    applied = meta_func(wrapper.shape_2d, apply_func_nb, *args)
    return wrapper.wrap(applied, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def rolling_apply(
    cls_or_self,
    window: tp.Optional[tp.FrequencyLike],
    reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.RangeReduceMetaFunc],
    *args,
    minp: tp.Optional[int] = None,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.nb.apply_reduce.rolling_reduce_nb` for integer windows
    and `vectorbtpro.generic.nb.apply_reduce.rolling_freq_reduce_nb` for frequency windows.

    For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.rolling_reduce_meta_nb`
    for integer windows and `vectorbtpro.generic.nb.apply_reduce.rolling_freq_reduce_meta_nb`
    for frequency windows.

    If `window` is None, it will become an expanding window: it is set to `wrapper.shape[0]`
    and `minp` defaults to 1. Otherwise, `minp` defaults to `window`.

    A window is treated as an integer window if it is an `int` or is recognized by
    `checks.is_int` (the same test that `GenericAccessor.ago` applies to its `n`); it is
    then normalized to a built-in `int`. Any other window is converted with
    `vectorbtpro.utils.datetime_.to_timedelta64` and treated as a frequency window.
    Frequency windows are dispatched without `minp`.

    If `reduce_func_nb` is a string, the function `<reduce_func_nb>_reduce_nb` is taken
    from `vectorbtpro.generic.nb`.

    Usage:
        * Using regular function:

        ```pycon
        >>> mean_nb = njit(lambda a: np.nanmean(a))

        >>> df.vbt.rolling_apply(3, mean_nb)
                      a    b         c
        2020-01-01  NaN  NaN       NaN
        2020-01-02  NaN  NaN       NaN
        2020-01-03  2.0  4.0  2.000000
        2020-01-04  3.0  3.0  2.333333
        2020-01-05  4.0  2.0  2.000000
        ```

        * Using a frequency-based window:

        ```pycon
        >>> df.vbt.rolling_apply("3d", mean_nb)
                      a    b         c
        2020-01-01  1.0  5.0  1.000000
        2020-01-02  1.5  4.5  1.500000
        2020-01-03  2.0  4.0  2.000000
        2020-01-04  3.0  3.0  2.333333
        2020-01-05  4.0  2.0  2.000000
        ```

        * Using meta function:

        ```pycon
        >>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\
        ...     np.mean(a[from_i:to_i, col]) / np.mean(b[from_i:to_i, col]))

        >>> vbt.pd_acc.rolling_apply(
        ...     3,
        ...     mean_ratio_meta_nb,
        ...     df.vbt.to_2d_array() - 1,
        ...     df.vbt.to_2d_array() + 1,
        ...     wrapper=df.vbt.wrapper,
        ... )
                           a         b         c
        2020-01-01       NaN       NaN       NaN
        2020-01-02       NaN       NaN       NaN
        2020-01-03  0.333333  0.600000  0.333333
        2020-01-04  0.500000  0.500000  0.400000
        2020-01-05  0.600000  0.333333  0.333333
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> vbt.pd_acc.rolling_apply(
        ...     2,
        ...     mean_ratio_meta_nb,
        ...     vbt.Rep('a'),
        ...     vbt.Rep('b'),
        ...     broadcast_named_args=dict(
        ...         a=pd.Series([1, 2, 3, 4, 5], index=df.index),
        ...         b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
        ...     )
        ... )
                      a     b         c
        2020-01-01  NaN   NaN       NaN
        2020-01-02  1.5  0.75  0.500000
        2020-01-03  2.5  1.25  0.833333
        2020-01-04  3.5  1.75  1.166667
        2020-01-05  4.5  2.25  1.500000
        ```
    """
    if broadcast_named_args is None:
        broadcast_named_args = {}
    if broadcast_kwargs is None:
        broadcast_kwargs = {}
    if template_context is None:
        template_context = {}

    is_meta = isinstance(cls_or_self, type)
    if is_meta:
        if len(broadcast_named_args) > 0:
            broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
            if wrapper is not None:
                broadcast_named_args = reshaping.broadcast(
                    broadcast_named_args,
                    to_shape=wrapper.shape_2d,
                    **broadcast_kwargs,
                )
            else:
                broadcast_named_args, wrapper = reshaping.broadcast(
                    broadcast_named_args,
                    return_wrapper=True,
                    **broadcast_kwargs,
                )
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
    elif wrapper is None:
        wrapper = cls_or_self.wrapper

    if window is None:
        # Expanding window: a single observation is enough unless the user says otherwise
        if minp is None:
            minp = 1
        window = wrapper.shape[0]
    elif not isinstance(window, int):
        if checks.is_int(window):
            # Integer that is not a built-in int: normalize it so that the dispatch below
            # selects the integer-window function instead of converting it to a timedelta
            window = int(window)
        else:
            window = dt.to_timedelta64(window)
    if minp is None:
        minp = window

    if isinstance(reduce_func_nb, str):
        reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")

    is_int_window = isinstance(window, int)
    if is_meta:
        template_context = merge_dicts(
            broadcast_named_args,
            dict(wrapper=wrapper, window=window, minp=minp),
            template_context,
        )
        args = substitute_templates(args, template_context, eval_id="args")
        if is_int_window:
            func = jit_reg.resolve_option(nb.rolling_reduce_meta_nb, jitted)
            func = ch_reg.resolve_option(func, chunked)
            out = func(wrapper.shape_2d, window, minp, reduce_func_nb, *args)
        else:
            func = jit_reg.resolve_option(nb.rolling_freq_reduce_meta_nb, jitted)
            func = ch_reg.resolve_option(func, chunked)
            out = func(wrapper.shape_2d[1], wrapper.index.values, window, reduce_func_nb, *args)
    else:
        if is_int_window:
            func = jit_reg.resolve_option(nb.rolling_reduce_nb, jitted)
            func = ch_reg.resolve_option(func, chunked)
            out = func(cls_or_self.to_2d_array(), window, minp, reduce_func_nb, *args)
        else:
            func = jit_reg.resolve_option(nb.rolling_freq_reduce_nb, jitted)
            func = ch_reg.resolve_option(func, chunked)
            out = func(wrapper.index.values, cls_or_self.to_2d_array(), window, reduce_func_nb, *args)

    return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
Usage: * Using regular function: ```pycon >>> mean_nb = njit(lambda a: np.nanmean(a)) >>> df.vbt.groupby_apply([1, 1, 2, 2, 3], mean_nb) a b c 1 1.5 4.5 1.5 2 3.5 2.5 2.5 3 5.0 1.0 1.0 ``` * Using meta function: ```pycon >>> mean_ratio_meta_nb = njit(lambda idxs, group, col, a, b: \\ ... np.mean(a[idxs, col]) / np.mean(b[idxs, col])) >>> vbt.pd_acc.groupby_apply( ... [1, 1, 2, 2, 3], ... mean_ratio_meta_nb, ... df.vbt.to_2d_array() - 1, ... df.vbt.to_2d_array() + 1, ... wrapper=df.vbt.wrapper ... ) a b c 1 0.200000 0.636364 0.200000 2 0.555556 0.428571 0.428571 3 0.666667 0.000000 0.000000 ``` * Using templates and broadcasting, let's split both input arrays into 2 groups of rows and run the calculation function on each group: ```pycon >>> from vectorbtpro.base.grouping.nb import group_by_evenly_nb >>> vbt.pd_acc.groupby_apply( ... vbt.RepEval('group_by_evenly_nb(wrapper.shape[0], 2)'), ... mean_ratio_meta_nb, ... vbt.Rep('a'), ... vbt.Rep('b'), ... broadcast_named_args=dict( ... a=pd.Series([1, 2, 3, 4, 5], index=df.index), ... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) ... ), ... template_context=dict(group_by_evenly_nb=group_by_evenly_nb) ... ) a b c 0 2.0 1.00 0.666667 1 4.5 2.25 1.500000 ``` The advantage of the approach above is in the flexibility: we can pass two arrays of any broadcastable shapes and everything else is done for us. 
""" if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} if isinstance(reduce_func_nb, str): reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb") if isinstance(cls_or_self, type): if len(broadcast_named_args) > 0: broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) if wrapper is not None: broadcast_named_args = reshaping.broadcast( broadcast_named_args, to_shape=wrapper.shape_2d, **broadcast_kwargs, ) else: broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) else: checks.assert_not_none(wrapper, arg_name="wrapper") template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context) by = substitute_templates(by, template_context, eval_id="by") else: if wrapper is None: wrapper = cls_or_self.wrapper grouper = wrapper.get_index_grouper(by, **resolve_dict(groupby_kwargs)) if isinstance(cls_or_self, type): group_map = grouper.get_group_map() template_context = merge_dicts(dict(by=by, grouper=grouper), template_context) args = substitute_templates(args, template_context, eval_id="args") func = jit_reg.resolve_option(nb.groupby_reduce_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(wrapper.shape_2d[1], group_map, reduce_func_nb, *args) else: group_map = grouper.get_group_map() func = jit_reg.resolve_option(nb.groupby_reduce_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.to_2d_array(), group_map, reduce_func_nb, *args) wrap_kwargs = merge_dicts(dict(name_or_index=grouper.get_index()), wrap_kwargs) return wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) @hybrid_method def groupby_transform( cls_or_self, by: tp.AnyGroupByLike, transform_func_nb: tp.Union[str, tp.GroupByTransformFunc, tp.GroupByTransformMetaFunc], *args, groupby_kwargs: tp.KwargsLike = None, broadcast_named_args: 
tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.generic.nb.apply_reduce.groupby_transform_nb`. For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.groupby_transform_meta_nb`. For argument `by`, see `GenericAccessor.groupby_apply`. Usage: * Using regular function: ```pycon >>> zscore_nb = njit(lambda a: (a - np.nanmean(a)) / np.nanstd(a)) >>> df.vbt.groupby_transform([1, 1, 2, 2, 3], zscore_nb) a b c 2020-01-01 -1.000000 1.666667 -1.000000 2020-01-02 -0.333333 1.000000 -0.333333 2020-01-03 0.242536 0.242536 0.242536 2020-01-04 1.697749 -1.212678 -1.212678 2020-01-05 1.414214 -0.707107 -0.707107 ``` * Using meta function: ```pycon >>> zscore_ratio_meta_nb = njit(lambda idxs, group, a, b: \\ ... zscore_nb(a[idxs]) / zscore_nb(b[idxs])) >>> vbt.pd_acc.groupby_transform( ... [1, 1, 2, 2, 3], ... zscore_ratio_meta_nb, ... df.vbt.to_2d_array(), ... df.vbt.to_2d_array()[::-1], ... wrapper=df.vbt.wrapper ... 
@hybrid_method
def resample_apply(
    cls_or_self,
    rule: tp.AnyRuleLike,
    reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.GroupByReduceMetaFunc, tp.RangeReduceMetaFunc],
    *args,
    use_groupby_apply: bool = False,
    freq: tp.Optional[tp.FrequencyLike] = None,
    resample_kwargs: tp.KwargsLike = None,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Resample.

    Argument `rule` can be an instance of `vectorbtpro.base.resampling.base.Resampler`,
    `pandas.core.resample.Resampler`, or any other frequency-like object that can be accepted by
    `pd.DataFrame.resample` with `resample_kwargs` passed as keyword arguments.

    If `use_groupby_apply` is True, uses `GenericAccessor.groupby_apply` (with some post-processing).
    Otherwise, uses `GenericAccessor.resample_to_index`.

    With `use_groupby_apply=True`, if the index returned by `GenericAccessor.groupby_apply`
    differs from the full index of the resampler, the result is placed into a NaN-filled object
    spanning the full resampled index. Keyword arguments `**kwargs` are forwarded to
    whichever of the two methods is used.

    `reduce_func_nb` can be a string, in which case it is resolved to the attribute
    `<reduce_func_nb>_reduce_nb` of `vectorbtpro.generic.nb`.

    Raises:
        TypeError: If `rule` is a `vectorbtpro.base.resampling.base.Resampler`
            and `use_groupby_apply` is True.

    Usage:
        * Using regular function:

        ```pycon
        >>> mean_nb = njit(lambda a: np.nanmean(a))

        >>> df.vbt.resample_apply('2d', mean_nb)
                      a    b    c
        2020-01-01  1.5  4.5  1.5
        2020-01-03  3.5  2.5  2.5
        2020-01-05  5.0  1.0  1.0
        ```

        * Using meta function:

        ```pycon
        >>> mean_ratio_meta_nb = njit(lambda idxs, group, col, a, b: \\
        ...     np.mean(a[idxs, col]) / np.mean(b[idxs, col]))

        >>> vbt.pd_acc.resample_apply(
        ...     '2d',
        ...     mean_ratio_meta_nb,
        ...     df.vbt.to_2d_array() - 1,
        ...     df.vbt.to_2d_array() + 1,
        ...     wrapper=df.vbt.wrapper
        ... )
                           a         b         c
        2020-01-01  0.200000  0.636364  0.200000
        2020-01-03  0.555556  0.428571  0.428571
        2020-01-05  0.666667  0.000000  0.000000
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> vbt.pd_acc.resample_apply(
        ...     '2d',
        ...     mean_ratio_meta_nb,
        ...     vbt.Rep('a'),
        ...     vbt.Rep('b'),
        ...     broadcast_named_args=dict(
        ...         a=pd.Series([1, 2, 3, 4, 5], index=df.index),
        ...         b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
        ...     )
        ... )
                      a     b         c
        2020-01-01  1.5  0.75  0.500000
        2020-01-03  3.5  1.75  1.166667
        2020-01-05  5.0  2.50  1.666667
        ```
    """
    if broadcast_named_args is None:
        broadcast_named_args = {}
    if broadcast_kwargs is None:
        broadcast_kwargs = {}
    if template_context is None:
        template_context = {}
    if isinstance(reduce_func_nb, str):
        reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")

    if isinstance(cls_or_self, type):
        if len(broadcast_named_args) > 0:
            broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
            if wrapper is not None:
                broadcast_named_args = reshaping.broadcast(
                    broadcast_named_args,
                    to_shape=wrapper.shape_2d,
                    **broadcast_kwargs,
                )
            else:
                broadcast_named_args, wrapper = reshaping.broadcast(
                    broadcast_named_args,
                    return_wrapper=True,
                    **broadcast_kwargs,
                )
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
        rule = substitute_templates(rule, template_context, eval_id="rule")
    else:
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    if use_groupby_apply:
        if isinstance(rule, Resampler):
            raise TypeError("Resampler cannot be used with use_groupby_apply=True")
        if not isinstance(rule, PandasResampler):
            # Only the index matters here: resample an empty-valued series to get a pandas resampler
            rule = pd.Series(index=wrapper.index, dtype=object).resample(rule, **resolve_dict(resample_kwargs))
        out_obj = cls_or_self.groupby_apply(
            rule,
            reduce_func_nb,
            *args,
            template_context=template_context,
            wrapper=wrapper,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )
        new_index = rule.count().index.rename("group")
        if pd.Index.equals(out_obj.index, new_index):
            if new_index.freq is not None:
                try:
                    out_obj.index.freq = new_index.freq
                except ValueError:
                    # Best effort: keep the result without a frequency if pandas rejects it
                    pass
            return out_obj
        resampled_arr = np.full((rule.ngroups, wrapper.shape_2d[1]), np.nan)
        # The array has one column per original column (same layout that `groupby_apply`
        # wraps with `group_by=False`), so column grouping must be disabled here as well.
        resampled_obj = wrapper.wrap(
            resampled_arr,
            index=new_index,
            group_by=False,
            **resolve_dict(wrap_kwargs),
        )
        resampled_obj.loc[out_obj.index] = out_obj.values
        return resampled_obj

    if not isinstance(rule, Resampler):
        rule = wrapper.get_resampler(
            rule,
            freq=freq,
            resample_kwargs=resample_kwargs,
            return_pd_resampler=False,
        )
    return cls_or_self.resample_to_index(
        rule,
        reduce_func_nb,
        *args,
        template_context=template_context,
        wrapper=wrapper,
        wrap_kwargs=wrap_kwargs,
        **kwargs,
    )
) a 3 b 3 c 0 Name: apply_and_reduce, dtype: int64 ``` """ if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} if isinstance(apply_func_nb, str): apply_func_nb = getattr(nb, apply_func_nb + "_apply_nb") if isinstance(reduce_func_nb, str): reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb") if apply_args is None: apply_args = () if reduce_args is None: reduce_args = () if isinstance(cls_or_self, type): if len(broadcast_named_args) > 0: broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) if wrapper is not None: broadcast_named_args = reshaping.broadcast( broadcast_named_args, to_shape=wrapper.shape_2d, **broadcast_kwargs, ) else: broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) else: checks.assert_not_none(wrapper, arg_name="wrapper") template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context) apply_args = substitute_templates(apply_args, template_context, eval_id="apply_args") reduce_args = substitute_templates(reduce_args, template_context, eval_id="reduce_args") func = jit_reg.resolve_option(nb.apply_and_reduce_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(wrapper.shape_2d[1], apply_func_nb, apply_args, reduce_func_nb, reduce_args) else: func = jit_reg.resolve_option(nb.apply_and_reduce_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.to_2d_array(), apply_func_nb, apply_args, reduce_func_nb, reduce_args) if wrapper is None: wrapper = cls_or_self.wrapper wrap_kwargs = merge_dicts(dict(name_or_index="apply_and_reduce"), wrap_kwargs) return wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) @hybrid_method def reduce( cls_or_self, reduce_func_nb: tp.Union[ str, tp.ReduceFunc, tp.ReduceMetaFunc, tp.ReduceToArrayFunc, tp.ReduceToArrayMetaFunc, tp.ReduceGroupedFunc, 
tp.ReduceGroupedMetaFunc, tp.ReduceGroupedToArrayFunc, tp.ReduceGroupedToArrayMetaFunc, ], *args, returns_array: bool = False, returns_idx: bool = False, flatten: bool = False, order: str = "C", to_index: bool = True, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeriesFrame: """Reduce by column/group. Set `flatten` to True when working with grouped data to pass a flattened array to `reduce_func_nb`. The order in which to flatten the array can be specified using `order`. Set `returns_array` to True if `reduce_func_nb` returns an array. Set `returns_idx` to True if `reduce_func_nb` returns row index/position. Set `to_index` to True to return labels instead of positions. For implementation details, see * `vectorbtpro.generic.nb.apply_reduce.reduce_flat_grouped_to_array_nb` if grouped, `returns_array` is True, and `flatten` is True * `vectorbtpro.generic.nb.apply_reduce.reduce_flat_grouped_nb` if grouped, `returns_array` is False, and `flatten` is True * `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_to_array_nb` if grouped, `returns_array` is True, and `flatten` is False * `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_nb` if grouped, `returns_array` is False, and `flatten` is False * `vectorbtpro.generic.nb.apply_reduce.reduce_to_array_nb` if not grouped and `returns_array` is True * `vectorbtpro.generic.nb.apply_reduce.reduce_nb` if not grouped and `returns_array` is False For implementation details on the meta versions, see * `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_to_array_meta_nb` if grouped and `returns_array` is True * `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_meta_nb` if grouped and `returns_array` is False * 
`vectorbtpro.generic.nb.apply_reduce.reduce_to_array_meta_nb` if not grouped and `returns_array` is True * `vectorbtpro.generic.nb.apply_reduce.reduce_meta_nb` if not grouped and `returns_array` is False `reduce_func_nb` can be a string denoting the suffix of a reducing function from `vectorbtpro.generic.nb`. For example, "sum" will refer to "sum_reduce_nb". Usage: * Using regular function: ```pycon >>> mean_nb = njit(lambda a: np.nanmean(a)) >>> df.vbt.reduce(mean_nb) a 3.0 b 3.0 c 1.8 Name: reduce, dtype: float64 >>> argmax_nb = njit(lambda a: np.argmax(a)) >>> df.vbt.reduce(argmax_nb, returns_idx=True) a 2020-01-05 b 2020-01-01 c 2020-01-03 Name: reduce, dtype: datetime64[ns] >>> df.vbt.reduce(argmax_nb, returns_idx=True, to_index=False) a 4 b 0 c 2 Name: reduce, dtype: int64 >>> min_max_nb = njit(lambda a: np.array([np.nanmin(a), np.nanmax(a)])) >>> df.vbt.reduce(min_max_nb, returns_array=True, wrap_kwargs=dict(name_or_index=['min', 'max'])) a b c min 1 1 1 max 5 5 3 >>> group_by = pd.Series(['first', 'first', 'second'], name='group') >>> df.vbt.reduce(mean_nb, group_by=group_by) group first 3.0 second 1.8 dtype: float64 ``` * Using meta function: ```pycon >>> mean_meta_nb = njit(lambda col, a: np.nanmean(a[:, col])) >>> pd.Series.vbt.reduce( ... mean_meta_nb, ... df['a'].vbt.to_2d_array(), ... wrapper=df['a'].vbt.wrapper ... ) 3.0 >>> vbt.pd_acc.reduce( ... mean_meta_nb, ... df.vbt.to_2d_array(), ... wrapper=df.vbt.wrapper ... ) a 3.0 b 3.0 c 1.8 Name: reduce, dtype: float64 >>> grouped_mean_meta_nb = njit(lambda group_idxs, group, a: np.nanmean(a[:, group_idxs])) >>> group_by = pd.Series(['first', 'first', 'second'], name='group') >>> vbt.pd_acc.reduce( ... grouped_mean_meta_nb, ... df.vbt.to_2d_array(), ... wrapper=df.vbt.wrapper, ... group_by=group_by ... ) group first 3.0 second 1.8 Name: reduce, dtype: float64 ``` * Using templates and broadcasting: ```pycon >>> mean_a_b_nb = njit(lambda col, a, b: \\ ... 
np.array([np.nanmean(a[:, col]), np.nanmean(b[:, col])])) >>> vbt.pd_acc.reduce( ... mean_a_b_nb, ... vbt.Rep('arr1'), ... vbt.Rep('arr2'), ... returns_array=True, ... broadcast_named_args=dict( ... arr1=pd.Series([1, 2, 3, 4, 5], index=df.index), ... arr2=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) ... ), ... wrap_kwargs=dict(name_or_index=['arr1', 'arr2']) ... ) a b c arr1 3.0 3.0 3.0 arr2 1.0 2.0 3.0 ``` """ if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} if isinstance(reduce_func_nb, str): reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb") if isinstance(cls_or_self, type): if len(broadcast_named_args) > 0: broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) if wrapper is not None: broadcast_named_args = reshaping.broadcast( broadcast_named_args, to_shape=wrapper.shape_2d, **broadcast_kwargs, ) else: broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) else: checks.assert_not_none(wrapper, arg_name="wrapper") template_context = merge_dicts( broadcast_named_args, dict( wrapper=wrapper, group_by=group_by, returns_array=returns_array, returns_idx=returns_idx, flatten=flatten, order=order, ), template_context, ) args = substitute_templates(args, template_context, eval_id="args") if wrapper.grouper.is_grouped(group_by=group_by): group_map = wrapper.grouper.get_group_map(group_by=group_by) if returns_array: func = jit_reg.resolve_option(nb.reduce_grouped_to_array_meta_nb, jitted) else: func = jit_reg.resolve_option(nb.reduce_grouped_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(group_map, reduce_func_nb, *args) else: if returns_array: func = jit_reg.resolve_option(nb.reduce_to_array_meta_nb, jitted) else: func = jit_reg.resolve_option(nb.reduce_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = 
@hybrid_method
def proximity_apply(
    cls_or_self,
    window: int,
    reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.ProximityReduceMetaFunc],
    *args,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.Frame:
    """Run `vectorbtpro.generic.nb.apply_reduce.proximity_reduce_nb` and wrap the result.

    The meta version, used when the method is called on the class, is
    `vectorbtpro.generic.nb.apply_reduce.proximity_reduce_meta_nb`. It requires `wrapper`
    unless one can be built by broadcasting `broadcast_named_args`. Templates in `args`
    are substituted using the broadcast arguments, `wrapper`, and `window`.

    A string `reduce_func_nb` is resolved to the attribute `<reduce_func_nb>_reduce_nb`
    of `vectorbtpro.generic.nb`.

    Usage:
        * Using regular function:

        ```pycon
        >>> mean_nb = njit(lambda a: np.nanmean(a))

        >>> df.vbt.proximity_apply(1, mean_nb)
                      a         b         c
        2020-01-01  3.0  2.500000  3.000000
        2020-01-02  3.0  2.666667  3.000000
        2020-01-03  3.0  2.777778  2.666667
        2020-01-04  3.0  2.666667  2.000000
        2020-01-05  3.0  2.500000  1.500000
        ```

        * Using meta function:

        ```pycon
        >>> @njit
        ... def mean_ratio_meta_nb(from_i, to_i, from_col, to_col, a, b):
        ...     a_mean = np.mean(a[from_i:to_i, from_col:to_col])
        ...     b_mean = np.mean(b[from_i:to_i, from_col:to_col])
        ...     return a_mean / b_mean

        >>> vbt.pd_acc.proximity_apply(
        ...     1,
        ...     mean_ratio_meta_nb,
        ...     df.vbt.to_2d_array() - 1,
        ...     df.vbt.to_2d_array() + 1,
        ...     wrapper=df.vbt.wrapper,
        ... )
                      a         b         c
        2020-01-01  0.5  0.428571  0.500000
        2020-01-02  0.5  0.454545  0.500000
        2020-01-03  0.5  0.470588  0.454545
        2020-01-04  0.5  0.454545  0.333333
        2020-01-05  0.5  0.428571  0.200000
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> vbt.pd_acc.proximity_apply(
        ...     1,
        ...     mean_ratio_meta_nb,
        ...     vbt.Rep('a'),
        ...     vbt.Rep('b'),
        ...     broadcast_named_args=dict(
        ...         a=pd.Series([1, 2, 3, 4, 5], index=df.index),
        ...         b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
        ...     )
        ... )
                           a     b    c
        2020-01-01  1.000000  0.75  0.6
        2020-01-02  1.333333  1.00  0.8
        2020-01-03  2.000000  1.50  1.2
        2020-01-04  2.666667  2.00  1.6
        2020-01-05  3.000000  2.25  1.8
        ```
    """
    is_meta = isinstance(cls_or_self, type)
    named_args = {} if broadcast_named_args is None else broadcast_named_args
    bc_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    context = {} if template_context is None else template_context

    # Resolve the wrapper: instance calls fall back to their own wrapper, class calls
    # either build one by broadcasting the named arguments or require it explicitly
    if not is_meta:
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    elif len(named_args) == 0:
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        bc_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), bc_kwargs)
        if wrapper is None:
            named_args, wrapper = reshaping.broadcast(
                named_args,
                return_wrapper=True,
                **bc_kwargs,
            )
        else:
            named_args = reshaping.broadcast(
                named_args,
                to_shape=wrapper.shape_2d,
                **bc_kwargs,
            )

    if isinstance(reduce_func_nb, str):
        reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")

    if is_meta:
        context = merge_dicts(
            named_args,
            dict(wrapper=wrapper, window=window),
            context,
        )
        args = substitute_templates(args, context, eval_id="args")
        meta_func = jit_reg.resolve_option(nb.proximity_reduce_meta_nb, jitted)
        result = meta_func(wrapper.shape_2d, window, reduce_func_nb, *args)
    else:
        regular_func = jit_reg.resolve_option(nb.proximity_reduce_nb, jitted)
        result = regular_func(cls_or_self.to_2d_array(), window, reduce_func_nb, *args)

    return wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs))
group of columns into a single column. See `vectorbtpro.generic.nb.apply_reduce.squeeze_grouped_nb`. For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.squeeze_grouped_meta_nb`. Usage: * Using regular function: ```pycon >>> mean_nb = njit(lambda a: np.nanmean(a)) >>> group_by = pd.Series(['first', 'first', 'second'], name='group') >>> df.vbt.squeeze_grouped(mean_nb, group_by=group_by) group first second 2020-01-01 3.0 1.0 2020-01-02 3.0 2.0 2020-01-03 3.0 3.0 2020-01-04 3.0 2.0 2020-01-05 3.0 1.0 ``` * Using meta function: ```pycon >>> mean_ratio_meta_nb = njit(lambda i, group_idxs, group, a, b: \\ ... np.mean(a[i][group_idxs]) / np.mean(b[i][group_idxs])) >>> vbt.pd_acc.squeeze_grouped( ... mean_ratio_meta_nb, ... df.vbt.to_2d_array() - 1, ... df.vbt.to_2d_array() + 1, ... wrapper=df.vbt.wrapper, ... group_by=group_by ... ) group first second 2020-01-01 0.5 0.000000 2020-01-02 0.5 0.333333 2020-01-03 0.5 0.500000 2020-01-04 0.5 0.333333 2020-01-05 0.5 0.000000 ``` * Using templates and broadcasting: ```pycon >>> vbt.pd_acc.squeeze_grouped( ... mean_ratio_meta_nb, ... vbt.Rep('a'), ... vbt.Rep('b'), ... broadcast_named_args=dict( ... a=pd.Series([1, 2, 3, 4, 5], index=df.index), ... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c']) ... ), ... group_by=[0, 0, 1] ... 
def flatten_grouped(
    self,
    order: str = "C",
    jitted: tp.JittedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Flatten each group of columns.

    Uses `vectorbtpro.generic.nb.apply_reduce.flatten_uniform_grouped_nb` if all groups have
    the same length, otherwise `vectorbtpro.generic.nb.apply_reduce.flatten_grouped_nb`.

    `order` must be "C" or "F" (case-insensitive). With "C", the index is repeated;
    with "F", the index is tiled (see the examples below).

    Raises a `ValueError` if the wrapper is not grouped under `group_by`.

    !!! warning
        Make sure that the distribution of group lengths is close to uniform, otherwise
        groups with less columns will be filled with NaN and needlessly occupy memory.

    Usage:
        ```pycon
        >>> group_by = pd.Series(['first', 'first', 'second'], name='group')
        >>> df.vbt.flatten_grouped(group_by=group_by, order='C')
        group       first  second
        2020-01-01    1.0     1.0
        2020-01-01    5.0     NaN
        2020-01-02    2.0     2.0
        2020-01-02    4.0     NaN
        2020-01-03    3.0     3.0
        2020-01-03    3.0     NaN
        2020-01-04    4.0     2.0
        2020-01-04    2.0     NaN
        2020-01-05    5.0     1.0
        2020-01-05    1.0     NaN

        >>> df.vbt.flatten_grouped(group_by=group_by, order='F')
        group       first  second
        2020-01-01    1.0     1.0
        2020-01-02    2.0     2.0
        2020-01-03    3.0     3.0
        2020-01-04    4.0     2.0
        2020-01-05    5.0     1.0
        2020-01-01    5.0     NaN
        2020-01-02    4.0     NaN
        2020-01-03    3.0     NaN
        2020-01-04    2.0     NaN
        2020-01-05    1.0     NaN
        ```
    """
    wrapper = self.wrapper
    if not wrapper.grouper.is_grouped(group_by=group_by):
        raise ValueError("Grouping required")
    order_code = order.upper()
    checks.assert_in(order_code, ["C", "F"])

    group_map = wrapper.grouper.get_group_map(group_by=group_by)
    group_lens = group_map[1]
    # The uniform variant applies only when every group has the same length as the first one
    if np.all(group_lens == group_lens.item(0)):
        flatten_nb = nb.flatten_uniform_grouped_nb
    else:
        flatten_nb = nb.flatten_grouped_nb
    func = jit_reg.resolve_option(flatten_nb, jitted)

    in_c_order = order_code == "C"
    out = func(self.to_2d_array(), group_map, in_c_order)
    # Each original row is expanded as many times as the longest group has columns
    expand_index = indexes.repeat_index if in_c_order else indexes.tile_index
    new_index = expand_index(wrapper.index, np.max(group_lens))

    wrap_kwargs = merge_dicts(dict(index=new_index), wrap_kwargs)
    return wrapper.wrap(out, group_by=group_by, **wrap_kwargs)
tp.FrequencyLike] = None, nan_value: tp.Optional[tp.Scalar] = None, ffill: bool = True, source_rbound: tp.Union[bool, str, tp.IndexLike] = False, target_rbound: tp.Union[bool, str, tp.IndexLike] = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, silence_warnings: tp.Optional[bool] = None, ) -> tp.MaybeSeriesFrame: """See `vectorbtpro.generic.nb.base.realign_nb`. `index` can be either an instance of `vectorbtpro.base.resampling.base.Resampler`, or any index-like object. Gives the same results as `df.resample(closed='right', label='right').last().ffill()` when applied on the target index of the resampler. Usage: * Downsampling: ```pycon >>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h') >>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d') >>> h_sr = pd.Series(range(len(h_index)), index=h_index) >>> h_sr.vbt.realign(d_index) 2020-01-01 0.0 2020-01-02 24.0 2020-01-03 48.0 2020-01-04 72.0 2020-01-05 96.0 Freq: D, dtype: float64 ``` * Upsampling: ```pycon >>> d_sr = pd.Series(range(len(d_index)), index=d_index) >>> d_sr.vbt.realign(h_index) 2020-01-01 00:00:00 0.0 2020-01-01 01:00:00 0.0 2020-01-01 02:00:00 0.0 2020-01-01 03:00:00 0.0 2020-01-01 04:00:00 0.0 ... ... 
2020-01-04 20:00:00 3.0 2020-01-04 21:00:00 3.0 2020-01-04 22:00:00 3.0 2020-01-04 23:00:00 3.0 2020-01-05 00:00:00 4.0 Freq: H, Length: 97, dtype: float64 ``` """ resampler = self.wrapper.get_resampler( index, freq=freq, return_pd_resampler=False, silence_warnings=silence_warnings, ) one_index = False if len(resampler.target_index) == 1 and checks.is_dt_like(index): if isinstance(index, str): try: dt.to_freq(index) one_index = False except Exception as e: one_index = True else: one_index = True if isinstance(source_rbound, bool): use_source_rbound = source_rbound else: use_source_rbound = False if isinstance(source_rbound, str): if source_rbound == "pandas": resampler = resampler.replace(source_index=resampler.source_rbound_index) else: raise ValueError(f"Invalid source_rbound: '{source_rbound}'") else: resampler = resampler.replace(source_index=source_rbound) if isinstance(target_rbound, bool): use_target_rbound = target_rbound index = resampler.target_index else: use_target_rbound = False index = resampler.target_index if isinstance(target_rbound, str): if target_rbound == "pandas": resampler = resampler.replace(target_index=resampler.target_rbound_index) else: raise ValueError(f"Invalid target_rbound: '{target_rbound}'") else: resampler = resampler.replace(target_index=target_rbound) if not use_source_rbound: source_freq = None else: source_freq = resampler.get_np_source_freq() if not use_target_rbound: target_freq = None else: target_freq = resampler.get_np_target_freq() if nan_value is None: if self.mapping is not None: nan_value = -1 else: nan_value = np.nan func = jit_reg.resolve_option(nb.realign_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), resampler.source_index.values, resampler.target_index.values, source_freq=source_freq, target_freq=target_freq, source_rbound=use_source_rbound, target_rbound=use_target_rbound, nan_value=nan_value, ffill=ffill, ) wrap_kwargs = merge_dicts(dict(index=index), wrap_kwargs) out = 
def realign_closing(self, *args, **kwargs) -> tp.MaybeSeriesFrame:
    """Call `GenericAccessor.realign` with `source_rbound=True` and `target_rbound=True`,
    that is, using the right bound of both the source and the target index.

    All positional and remaining keyword arguments are forwarded unchanged.

    !!! note
        The timestamps in the source and target index should denote the open time."""
    right_bounds = dict(source_rbound=True, target_rbound=True)
    return self.realign(*args, **right_bounds, **kwargs)
Usage: * Downsampling: ```pycon >>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h') >>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d') >>> h_sr = pd.Series(range(len(h_index)), index=h_index) >>> h_sr.vbt.resample_to_index(d_index, njit(lambda x: x.mean())) 2020-01-01 11.5 2020-01-02 35.5 2020-01-03 59.5 2020-01-04 83.5 2020-01-05 96.0 Freq: D, dtype: float64 >>> h_sr.vbt.resample_to_index(d_index, njit(lambda x: x.mean()), before=True) 2020-01-01 0.0 2020-01-02 12.5 2020-01-03 36.5 2020-01-04 60.5 2020-01-05 84.5 Freq: D, dtype: float64 ``` * Upsampling: ```pycon >>> d_sr = pd.Series(range(len(d_index)), index=d_index) >>> d_sr.vbt.resample_to_index(h_index, njit(lambda x: x[-1])) 2020-01-01 00:00:00 0.0 2020-01-01 01:00:00 NaN 2020-01-01 02:00:00 NaN 2020-01-01 03:00:00 NaN 2020-01-01 04:00:00 NaN ... ... 2020-01-04 20:00:00 NaN 2020-01-04 21:00:00 NaN 2020-01-04 22:00:00 NaN 2020-01-04 23:00:00 NaN 2020-01-05 00:00:00 4.0 Freq: H, Length: 97, dtype: float64 ``` * Using meta function: ```pycon >>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\ ... np.mean(a[from_i:to_i][col]) / np.mean(b[from_i:to_i][col])) >>> vbt.pd_acc.resample_to_index( ... d_index, ... mean_ratio_meta_nb, ... h_sr.vbt.to_2d_array() - 1, ... h_sr.vbt.to_2d_array() + 1, ... wrapper=h_sr.vbt.wrapper ... ) 2020-01-01 -1.000000 2020-01-02 0.920000 2020-01-03 0.959184 2020-01-04 0.972603 2020-01-05 0.979381 Freq: D, dtype: float64 ``` * Using templates and broadcasting: ```pycon >>> vbt.pd_acc.resample_to_index( ... d_index, ... mean_ratio_meta_nb, ... vbt.Rep('a'), ... vbt.Rep('b'), ... broadcast_named_args=dict( ... a=h_sr - 1, ... b=h_sr + 1 ... ) ... 
) 2020-01-01 -1.000000 2020-01-02 0.920000 2020-01-03 0.959184 2020-01-04 0.972603 2020-01-05 0.979381 Freq: D, dtype: float64 ``` """ if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} if isinstance(reduce_func_nb, str): reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb") if isinstance(cls_or_self, type): if len(broadcast_named_args) > 0: broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) if wrapper is not None: broadcast_named_args = reshaping.broadcast( broadcast_named_args, to_shape=wrapper.shape_2d, **broadcast_kwargs, ) else: broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) else: checks.assert_not_none(wrapper, arg_name="wrapper") template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context) index = substitute_templates(index, template_context, eval_id="index") else: if wrapper is None: wrapper = cls_or_self.wrapper resampler = wrapper.get_resampler( index, freq=freq, return_pd_resampler=False, silence_warnings=silence_warnings, ) index_ranges = resampler.map_index_to_source_ranges(before=before, jitted=jitted) if isinstance(cls_or_self, type): template_context = merge_dicts( dict( resampler=resampler, index_ranges=index_ranges, ), template_context, ) args = substitute_templates(args, template_context, eval_id="args") func = jit_reg.resolve_option(nb.reduce_index_ranges_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( wrapper.shape_2d[1], index_ranges[0], index_ranges[1], reduce_func_nb, *args, ) else: func = jit_reg.resolve_option(nb.reduce_index_ranges_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( cls_or_self.to_2d_array(), index_ranges[0], index_ranges[1], reduce_func_nb, *args, ) wrap_kwargs = merge_dicts(dict(index=resampler.target_index), wrap_kwargs) return 
@hybrid_method
def resample_between_bounds(
    cls_or_self,
    target_lbound_index: tp.IndexLike,
    target_rbound_index: tp.IndexLike,
    reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.RangeReduceMetaFunc],
    *args,
    closed_lbound: bool = True,
    closed_rbound: bool = False,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_with_lbound: tp.Optional[bool] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Reduce the source values that fall between each pair of target bounds.

    Index ranges are obtained from `Resampler.map_bounds_to_source_ranges` and reduced with
    `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_nb` (instance call) or
    `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_meta_nb` (class call).

    If `reduce_func_nb` is a string, the function `<name>_reduce_nb` is looked up in `nb`.

    If one of the bounds has a single element and the other has more, the single element
    is repeated to the length of the other one. The result is indexed by the left bound if
    `wrap_with_lbound` is True and by the right bound otherwise. When `wrap_with_lbound` is
    None, it defaults to the bound that was not repeated, or, if nothing was repeated, to the
    left bound unless only the right bound is closed.

    When called as a class method, `wrapper` must be provided unless `broadcast_named_args`
    is non-empty (in which case the wrapper can be derived from broadcasting), and templates in
    both bounds and `args` are substituted using `template_context`.

    Usage:
        * Using regular function:

        ```pycon
        >>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h')
        >>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d')

        >>> h_sr = pd.Series(range(len(h_index)), index=h_index)
        >>> h_sr.vbt.resample_between_bounds(d_index, d_index.shift(), njit(lambda x: x.mean()))
        2020-01-01    11.5
        2020-01-02    35.5
        2020-01-03    59.5
        2020-01-04    83.5
        2020-01-05    96.0
        Freq: D, dtype: float64
        ```

        * Using meta function:

        ```pycon
        >>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\
        ...     np.mean(a[from_i:to_i][col]) / np.mean(b[from_i:to_i][col]))

        >>> vbt.pd_acc.resample_between_bounds(
        ...     d_index,
        ...     d_index.shift(),
        ...     mean_ratio_meta_nb,
        ...     h_sr.vbt.to_2d_array() - 1,
        ...     h_sr.vbt.to_2d_array() + 1,
        ...     wrapper=h_sr.vbt.wrapper
        ... )
        2020-01-01   -1.000000
        2020-01-02    0.920000
        2020-01-03    0.959184
        2020-01-04    0.972603
        2020-01-05    0.979381
        Freq: D, dtype: float64
        ```

        * Using templates and broadcasting:

        ```pycon
        >>> vbt.pd_acc.resample_between_bounds(
        ...     d_index,
        ...     d_index.shift(),
        ...     mean_ratio_meta_nb,
        ...     vbt.Rep('a'),
        ...     vbt.Rep('b'),
        ...     broadcast_named_args=dict(
        ...         a=h_sr - 1,
        ...         b=h_sr + 1
        ...     )
        ... )
        2020-01-01   -1.000000
        2020-01-02    0.920000
        2020-01-03    0.959184
        2020-01-04    0.972603
        2020-01-05    0.979381
        Freq: D, dtype: float64
        ```
    """
    broadcast_named_args = {} if broadcast_named_args is None else broadcast_named_args
    broadcast_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    template_context = {} if template_context is None else template_context
    # True when invoked on the class (meta mode), False when invoked on an instance
    is_meta = isinstance(cls_or_self, type)

    if isinstance(reduce_func_nb, str):
        reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")

    if not is_meta:
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        if len(broadcast_named_args) == 0:
            # Nothing to derive the wrapper from, so it must be given explicitly
            checks.assert_not_none(wrapper, arg_name="wrapper")
        else:
            broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
            if wrapper is None:
                broadcast_named_args, wrapper = reshaping.broadcast(
                    broadcast_named_args,
                    return_wrapper=True,
                    **broadcast_kwargs,
                )
            else:
                broadcast_named_args = reshaping.broadcast(
                    broadcast_named_args,
                    to_shape=wrapper.shape_2d,
                    **broadcast_kwargs,
                )
        template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
        target_lbound_index = substitute_templates(
            target_lbound_index,
            template_context,
            eval_id="target_lbound_index",
        )
        target_rbound_index = substitute_templates(
            target_rbound_index,
            template_context,
            eval_id="target_rbound_index",
        )

    target_lbound_index = dt.prepare_dt_index(target_lbound_index)
    target_rbound_index = dt.prepare_dt_index(target_rbound_index)
    n_lbounds = len(target_lbound_index)
    n_rbounds = len(target_rbound_index)
    if n_lbounds == 1 and n_rbounds > 1:
        # One left bound shared by all right bounds: label by the (distinct) right bounds
        target_lbound_index = repeat_index(target_lbound_index, n_rbounds)
        if wrap_with_lbound is None:
            wrap_with_lbound = False
    elif n_lbounds > 1 and n_rbounds == 1:
        # One right bound shared by all left bounds: label by the (distinct) left bounds
        target_rbound_index = repeat_index(target_rbound_index, n_lbounds)
        if wrap_with_lbound is None:
            wrap_with_lbound = True

    index_ranges = Resampler.map_bounds_to_source_ranges(
        source_index=wrapper.index.values,
        target_lbound_index=target_lbound_index.values,
        target_rbound_index=target_rbound_index.values,
        closed_lbound=closed_lbound,
        closed_rbound=closed_rbound,
        skip_not_found=False,
        jitted=jitted,
    )

    if is_meta:
        template_context = merge_dicts(
            dict(
                target_lbound_index=target_lbound_index,
                target_rbound_index=target_rbound_index,
                index_ranges=index_ranges,
            ),
            template_context,
        )
        args = substitute_templates(args, template_context, eval_id="args")
        reduce_nb = nb.reduce_index_ranges_meta_nb
        # The meta kernel receives the number of columns instead of the array itself
        first_arg = wrapper.shape_2d[1]
    else:
        reduce_nb = nb.reduce_index_ranges_nb
        first_arg = cls_or_self.to_2d_array()
    func = ch_reg.resolve_option(jit_reg.resolve_option(reduce_nb, jitted), chunked)
    out = func(first_arg, index_ranges[0], index_ranges[1], reduce_func_nb, *args)

    if wrap_with_lbound is None:
        # Prefer the left bound unless only the right bound is closed
        wrap_with_lbound = bool(closed_lbound) or not closed_rbound
    out_index = target_lbound_index if wrap_with_lbound else target_rbound_index
    wrap_kwargs = merge_dicts(dict(index=out_index), wrap_kwargs)
    return wrapper.wrap(out, group_by=False, **wrap_kwargs)
def max(
    self,
    use_jitted: tp.Optional[bool] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Return max of non-NaN elements.

    Grouped columns are reduced per group with `nb.max_reduce_nb` via `reduce`.
    Otherwise each column is reduced along the first axis, using `nb.nanmax_nb` if
    `use_jitted` is True (defaults to `use_jitted` in the `generic` settings) and
    a non-jitted `nanmax` otherwise."""
    wrap_kwargs = merge_dicts(dict(name_or_index="max"), wrap_kwargs)

    if not self.wrapper.grouper.is_grouped(group_by=group_by):
        from vectorbtpro._settings import settings

        generic_cfg = settings["generic"]
        values = self.to_2d_array()
        if use_jitted is None:
            use_jitted = generic_cfg["use_jitted"]

        if use_jitted:
            reducer = jit_reg.resolve_option(nb.nanmax_nb, jitted)
        elif values.dtype == int or values.dtype == float:
            reducer = partial(nanmax, axis=0)
        else:
            # bottleneck can't consume other dtypes, so fall back to NumPy
            reducer = partial(np.nanmax, axis=0)
        reducer = ch_reg.resolve_option(nb.nanmax_nb, chunked, target_func=reducer)
        return self.wrapper.wrap_reduced(reducer(values), group_by=False, **wrap_kwargs)

    return self.reduce(
        jit_reg.resolve_option(nb.max_reduce_nb, jitted),
        flatten=True,
        jitted=jitted,
        chunked=chunked,
        group_by=group_by,
        wrap_kwargs=wrap_kwargs,
    )
def median(
    self,
    use_jitted: tp.Optional[bool] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Return median of non-NaN elements.

    Grouped columns are reduced per group with `nb.median_reduce_nb` via `reduce`.
    Otherwise each column is reduced along the first axis, using `nb.nanmedian_nb` if
    `use_jitted` is True (defaults to `use_jitted` in the `generic` settings) and
    a non-jitted `nanmedian` otherwise."""
    wrap_kwargs = merge_dicts(dict(name_or_index="median"), wrap_kwargs)
    is_grouped = self.wrapper.grouper.is_grouped(group_by=group_by)
    if is_grouped:
        grouped_reducer = jit_reg.resolve_option(nb.median_reduce_nb, jitted)
        return self.reduce(
            grouped_reducer,
            flatten=True,
            jitted=jitted,
            chunked=chunked,
            group_by=group_by,
            wrap_kwargs=wrap_kwargs,
        )

    from vectorbtpro._settings import settings

    generic_cfg = settings["generic"]
    data = self.to_2d_array()
    if use_jitted is None:
        use_jitted = generic_cfg["use_jitted"]

    if use_jitted:
        column_reducer = jit_reg.resolve_option(nb.nanmedian_nb, jitted)
    else:
        # bottleneck can't consume dtypes other than int and float
        bottleneck_ok = data.dtype == int or data.dtype == float
        column_reducer = partial(nanmedian if bottleneck_ok else np.nanmedian, axis=0)
    column_reducer = ch_reg.resolve_option(nb.nanmedian_nb, chunked, target_func=column_reducer)
    reduced = column_reducer(data)
    return self.wrapper.wrap_reduced(reduced, group_by=False, **wrap_kwargs)
def sum(
    self,
    use_jitted: tp.Optional[bool] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Return sum of non-NaN elements.

    Grouped columns are reduced per group with `nb.sum_reduce_nb` via `reduce`.
    Otherwise each column is reduced along the first axis, using `nb.nansum_nb` if
    `use_jitted` is True (defaults to `use_jitted` in the `generic` settings) and
    a non-jitted `nansum` otherwise."""
    wrap_kwargs = merge_dicts(dict(name_or_index="sum"), wrap_kwargs)
    if self.wrapper.grouper.is_grouped(group_by=group_by):
        reduce_kwargs = dict(
            flatten=True,
            jitted=jitted,
            chunked=chunked,
            group_by=group_by,
            wrap_kwargs=wrap_kwargs,
        )
        return self.reduce(jit_reg.resolve_option(nb.sum_reduce_nb, jitted), **reduce_kwargs)

    from vectorbtpro._settings import settings

    generic_cfg = settings["generic"]
    arr_2d = self.to_2d_array()
    if use_jitted is None:
        use_jitted = generic_cfg["use_jitted"]

    if use_jitted:
        sum_func = jit_reg.resolve_option(nb.nansum_nb, jitted)
    elif arr_2d.dtype == int or arr_2d.dtype == float:
        sum_func = partial(nansum, axis=0)
    else:
        # bottleneck can't consume other dtypes, so fall back to NumPy
        sum_func = partial(np.nansum, axis=0)
    sum_func = ch_reg.resolve_option(nb.nansum_nb, chunked, target_func=sum_func)
    totals = sum_func(arr_2d)
    return self.wrapper.wrap_reduced(totals, group_by=False, **wrap_kwargs)
def cov(
    self,
    other: tp.SeriesFrame,
    ddof: int = 1,
    broadcast_kwargs: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Return covariance of non-NaN elements.

    This object and `other` are broadcast against each other (see `broadcast_kwargs`)
    before the covariance is computed with `ddof` delta degrees of freedom.

    Without grouping, `nb.nancov_nb` is applied to the two broadcast arrays. With grouping,
    `nb.cov_reduce_grouped_meta_nb` is run through the class-level `reduce` using a wrapper
    built from the broadcast version of this object."""
    left_obj, right_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
    left_arr = reshaping.to_2d_array(left_obj)
    right_arr = reshaping.to_2d_array(right_obj)
    wrap_kwargs = merge_dicts(dict(name_or_index="cov"), wrap_kwargs)

    if not self.wrapper.grouper.is_grouped(group_by=group_by):
        cov_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.nancov_nb, jitted), chunked)
        # NOTE(review): the result is wrapped with `self.wrapper`, whereas the grouped path
        # below builds its wrapper from the broadcast object - confirm that this is intended
        # when broadcasting changes the columns.
        return self.wrapper.wrap_reduced(
            cov_func(left_arr, right_arr, ddof=ddof),
            group_by=False,
            **wrap_kwargs,
        )

    return type(self).reduce(
        jit_reg.resolve_option(nb.cov_reduce_grouped_meta_nb, jitted),
        left_arr,
        right_arr,
        ddof,
        flatten=True,
        jitted=jitted,
        chunked=chunked,
        wrapper=ArrayWrapper.from_obj(left_obj),
        group_by=self.wrapper.grouper.resolve_group_by(group_by=group_by),
        wrap_kwargs=wrap_kwargs,
    )
group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Return correlation coefficient of non-NaN elements.""" self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs)) self_arr = reshaping.to_2d_array(self_obj) other_arr = reshaping.to_2d_array(other_obj) wrap_kwargs = merge_dicts(dict(name_or_index="corr"), wrap_kwargs) if self.wrapper.grouper.is_grouped(group_by=group_by): return type(self).reduce( jit_reg.resolve_option(nb.corr_reduce_grouped_meta_nb, jitted), self_arr, other_arr, flatten=True, jitted=jitted, chunked=chunked, wrapper=ArrayWrapper.from_obj(self_obj), group_by=self.wrapper.grouper.resolve_group_by(group_by=group_by), wrap_kwargs=wrap_kwargs, ) func = jit_reg.resolve_option(nb.nancorr_nb, jitted) func = ch_reg.resolve_option(func, chunked) return self.wrapper.wrap_reduced(func(self_arr, other_arr), group_by=False, **wrap_kwargs) def rank( self, pct: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Compute numerical data rank. 
def idxmin(
    self,
    order: str = "C",
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Return labeled index of min of non-NaN elements.

    Grouped columns are reduced per group with `nb.argmin_reduce_nb` via `reduce`
    (flattened in the given `order`). Otherwise the label is looked up per column;
    columns that contain only NaN yield NaN."""
    wrap_kwargs = merge_dicts(dict(name_or_index="idxmin"), wrap_kwargs)
    if self.wrapper.grouper.is_grouped(group_by=group_by):
        return self.reduce(
            jit_reg.resolve_option(nb.argmin_reduce_nb, jitted),
            returns_idx=True,
            flatten=True,
            order=order,
            jitted=jitted,
            chunked=chunked,
            group_by=group_by,
            wrap_kwargs=wrap_kwargs,
        )

    # The parameter names are kept as `arr` and `index`: the chunking specification
    # below addresses `index` by name
    def func(arr, index):
        has_value = ~np.all(np.isnan(arr), axis=0)
        labels = np.full(arr.shape[1], np.nan, dtype=object)
        # Only columns with at least one non-NaN element have a minimum to locate
        labels[has_value] = index[nanargmin(arr[:, has_value], axis=0)]
        return labels

    # The index must be passed whole to every chunk rather than being split
    chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(index=None))
    resolved_func = ch_reg.resolve_option(nb.nanmin_nb, chunked, target_func=func)
    min_labels = resolved_func(self.to_2d_array(), self.wrapper.index)
    return self.wrapper.wrap_reduced(min_labels, group_by=False, **wrap_kwargs)
np.full(arr.shape[1], np.nan, dtype=object) nan_mask = np.all(np.isnan(arr), axis=0) out[~nan_mask] = index[nanargmax(arr[:, ~nan_mask], axis=0)] return out chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(index=None)) func = ch_reg.resolve_option(nb.nanmax_nb, chunked, target_func=func) out = func(self.to_2d_array(), self.wrapper.index) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def describe( self, percentiles: tp.Optional[tp.ArrayLike] = None, ddof: int = 1, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.generic.nb.apply_reduce.describe_reduce_nb`. For `percentiles`, see `pd.DataFrame.describe`. Usage: ```pycon >>> df.vbt.describe() a b c count 5.000000 5.000000 5.00000 mean 3.000000 3.000000 1.80000 std 1.581139 1.581139 0.83666 min 1.000000 1.000000 1.00000 25% 2.000000 2.000000 1.00000 50% 3.000000 3.000000 2.00000 75% 4.000000 4.000000 2.00000 max 5.000000 5.000000 3.00000 ``` """ if percentiles is not None: percentiles = reshaping.to_1d_array(percentiles) else: percentiles = np.array([0.25, 0.5, 0.75]) percentiles = percentiles.tolist() if 0.5 not in percentiles: percentiles.append(0.5) percentiles = np.unique(percentiles) perc_formatted = pd.io.formats.format.format_percentiles(percentiles) index = pd.Index(["count", "mean", "std", "min", *perc_formatted, "max"]) wrap_kwargs = merge_dicts(dict(name_or_index=index), wrap_kwargs) chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(args=ch.ArgsTaker(None, None))) if self.wrapper.grouper.is_grouped(group_by=group_by): return self.reduce( jit_reg.resolve_option(nb.describe_reduce_nb, jitted), percentiles, ddof, returns_array=True, flatten=True, jitted=jitted, chunked=chunked, group_by=group_by, wrap_kwargs=wrap_kwargs, ) else: return self.reduce( jit_reg.resolve_option(nb.describe_reduce_nb, jitted), percentiles, ddof, 
def digitize(
    self,
    bins: tp.ArrayLike = "auto",
    right: bool = False,
    return_mapping: bool = False,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.SeriesFrame, tp.Tuple[tp.SeriesFrame, dict]]:
    """Apply `np.digitize`.

    Args:
        bins (str, int or array_like): Bin specification.

            * String (such as the default "auto"): name of a bin-width estimator understood by
              `np.histogram_bin_edges`; it is resolved to a number of bins using all non-NaN values.
            * Integer: number of equal-width bins between the min and max of the data.
            * Any other iterable: explicit bin edges.
        right (bool): Whether the intervals include the right or the left bin edge.
        return_mapping (bool): Whether to also return a dict that maps each
            bin number to its `(left_edge, right_edge)` tuple.
        wrap_kwargs (dict): Keyword arguments passed to the wrapper's `wrap` method.

    Raises:
        ValueError: If the number of bins is less than 1, if the bins must be derived from
            the data but the data contains infinity, or if `bins` is a string and
            the data contains no non-NaN values.

    Usage:
        ```pycon
        >>> df.vbt.digitize(3)
                    a  b  c
        2020-01-01  1  3  1
        2020-01-02  1  3  1
        2020-01-03  2  2  2
        2020-01-04  3  1  1
        2020-01-05  3  1  1
        ```
    """
    if wrap_kwargs is None:
        wrap_kwargs = {}

    arr = self.to_2d_array()
    if isinstance(bins, str):
        # A string is iterable, so without this branch it would skip the edge computation
        # below and be used as bin edges itself. Resolve it to a bin count instead and
        # let the integer path build (and adjust) the edges.
        valid = arr[~np.isnan(arr)]
        if valid.size == 0:
            raise ValueError("Cannot infer bins when input data contains only NaN")
        if np.isinf(valid).any():
            raise ValueError("Cannot infer bins when input data contains infinity")
        bins = len(np.histogram_bin_edges(valid, bins=bins)) - 1
    if not np.iterable(bins):
        if np.isscalar(bins) and bins < 1:
            raise ValueError("Bins must be a positive integer")

        rng = (np.nanmin(self.obj.values), np.nanmax(self.obj.values))
        mn, mx = (mi + 0.0 for mi in rng)

        if np.isinf(mn) or np.isinf(mx):
            raise ValueError("Cannot specify integer bins when input data contains infinity")
        elif mn == mx:  # adjust end points before binning
            mn -= 0.001 * abs(mn) if mn != 0 else 0.001
            mx += 0.001 * abs(mx) if mx != 0 else 0.001
            bins = np.linspace(mn, mx, bins + 1, endpoint=True)
        else:  # adjust end points after binning
            bins = np.linspace(mn, mx, bins + 1, endpoint=True)
            adj = (mx - mn) * 0.001  # 0.1% of the range
            if right:
                bins[0] -= adj
            else:
                bins[-1] += adj

    bin_edges = reshaping.to_1d_array(bins)
    mapping = dict()
    if right:
        # Drop the leftmost edge so that bin numbers start at 0
        out = np.digitize(arr, bin_edges[1:], right=right)
        if return_mapping:
            for i in range(len(bin_edges) - 1):
                mapping[i] = (bin_edges[i], bin_edges[i + 1])
    else:
        # Drop the rightmost edge so that bin numbers start at 1
        out = np.digitize(arr, bin_edges[:-1], right=right)
        if return_mapping:
            for i in range(1, len(bin_edges)):
                mapping[i] = (bin_edges[i - 1], bin_edges[i])

    if return_mapping:
        return self.wrapper.wrap(out, **wrap_kwargs), mapping
    return self.wrapper.wrap(out, **wrap_kwargs)
None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.SeriesFrame: """Return a Series/DataFrame containing counts of unique values. Args: axis (int): 0 - counts per row, 1 - counts per column, and -1 - counts across the whole object. normalize (bool): Whether to return the relative frequencies of the unique values. sort_uniques (bool): Whether to sort uniques. sort (bool): Whether to sort by frequency. ascending (bool): Whether to sort in ascending order. dropna (bool): Whether to exclude counts of NaN. group_by (any): Group or ungroup columns. See `vectorbtpro.base.grouping.base.Grouper`. mapping (mapping_like): Mapping of values to labels. incl_all_keys (bool): Whether to include all mapping keys, no only those that are present in the array. jitted (any): Whether to JIT-compile `vectorbtpro.generic.nb.base.value_counts_nb` or options. chunked (any): Whether to chunk `vectorbtpro.generic.nb.base.value_counts_nb` or options. See `vectorbtpro.utils.chunking.resolve_chunked`. wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`. **kwargs: Keyword arguments passed to `vectorbtpro.utils.mapping.apply_mapping`. 
Usage: ```pycon >>> df.vbt.value_counts() a b c 1 1 1 2 2 1 1 2 3 1 1 1 4 1 1 0 5 1 1 0 >>> df.vbt.value_counts(axis=-1) 1 4 2 4 3 3 4 2 5 2 Name: value_counts, dtype: int64 >>> mapping = {x: 'test_' + str(x) for x in pd.unique(df.values.flatten())} >>> df.vbt.value_counts(mapping=mapping) a b c test_1 1 1 2 test_2 1 1 2 test_3 1 1 1 test_4 1 1 0 test_5 1 1 0 >>> sr = pd.Series([1, 2, 2, 3, 3, 3, np.nan]) >>> sr.vbt.value_counts(mapping=mapping) test_1 1 test_2 2 test_3 3 NaN 1 dtype: int64 >>> sr.vbt.value_counts(mapping=mapping, dropna=True) test_1 1 test_2 2 test_3 3 dtype: int64 >>> sr.vbt.value_counts(mapping=mapping, sort=True) test_3 3 test_2 2 test_1 1 NaN 1 dtype: int64 >>> sr.vbt.value_counts(mapping=mapping, sort=True, ascending=True) test_1 1 NaN 1 test_2 2 test_3 3 dtype: int64 >>> sr.vbt.value_counts(mapping=mapping, incl_all_keys=True) test_1 1 test_2 2 test_3 3 test_4 0 test_5 0 NaN 1 dtype: int64 ``` """ checks.assert_in(axis, (-1, 0, 1)) mapping = self.resolve_mapping(mapping=mapping) codes, uniques = pd.factorize(self.obj.values.flatten(), sort=False, use_na_sentinel=False) if axis == 0: func = jit_reg.resolve_option(nb.value_counts_per_row_nb, jitted) func = ch_reg.resolve_option(func, chunked) value_counts = func(codes.reshape(self.wrapper.shape_2d), len(uniques)) elif axis == 1: group_map = self.wrapper.grouper.get_group_map(group_by=group_by) func = jit_reg.resolve_option(nb.value_counts_nb, jitted) func = ch_reg.resolve_option(func, chunked) value_counts = func(codes.reshape(self.wrapper.shape_2d), len(uniques), group_map) else: func = jit_reg.resolve_option(nb.value_counts_1d_nb, jitted) value_counts = func(codes, len(uniques)) if incl_all_keys and mapping is not None: missing_keys = [] for x in mapping: if pd.isnull(x) and pd.isnull(uniques).any(): continue if x not in uniques: missing_keys.append(x) if axis == 0 or axis == 1: value_counts = np.vstack((value_counts, np.full((len(missing_keys), value_counts.shape[1]), 0))) else: 
value_counts = np.concatenate((value_counts, np.full(len(missing_keys), 0))) uniques = np.concatenate((uniques, np.array(missing_keys))) nan_mask = np.isnan(uniques) if dropna: value_counts = value_counts[~nan_mask] uniques = uniques[~nan_mask] if sort_uniques: new_indices = uniques.argsort() value_counts = value_counts[new_indices] uniques = uniques[new_indices] if axis == 0 or axis == 1: value_counts_sum = value_counts.sum(axis=1) else: value_counts_sum = value_counts if normalize: value_counts = value_counts / value_counts_sum.sum() if sort: if ascending: new_indices = value_counts_sum.argsort() else: new_indices = (-value_counts_sum).argsort() value_counts = value_counts[new_indices] uniques = uniques[new_indices] if axis == 0: wrapper = ArrayWrapper.from_obj(value_counts) value_counts_pd = wrapper.wrap( value_counts, index=uniques, columns=self.wrapper.index, **resolve_dict(wrap_kwargs), ) elif axis == 1: value_counts_pd = self.wrapper.wrap( value_counts, index=uniques, group_by=group_by, **resolve_dict(wrap_kwargs), ) else: wrapper = ArrayWrapper.from_obj(value_counts) value_counts_pd = wrapper.wrap( value_counts, index=uniques, **merge_dicts(dict(columns=["value_counts"]), wrap_kwargs), ) if mapping is not None: value_counts_pd.index = apply_mapping(value_counts_pd.index, mapping, **kwargs) return value_counts_pd # ############# Transforming ############# # def demean( self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.generic.nb.base.demean_nb`.""" func = jit_reg.resolve_option(nb.demean_nb, jitted) func = ch_reg.resolve_option(func, chunked) group_map = self.wrapper.grouper.get_group_map(group_by=group_by) out = func(self.to_2d_array(), group_map) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def transform(self, transformer: TransformerT, wrap_kwargs: tp.KwargsLike = None, **kwargs) -> 
tp.SeriesFrame: """Transform using a transformer. A transformer can be any class instance that has `transform` and `fit_transform` methods, ideally subclassing `sklearn.base.TransformerMixin` and `sklearn.base.BaseEstimator`. Will fit `transformer` if not fitted. `**kwargs` are passed to the `transform` or `fit_transform` method. Usage: ```pycon >>> from sklearn.preprocessing import MinMaxScaler >>> df.vbt.transform(MinMaxScaler((-1, 1))) a b c 2020-01-01 -1.0 1.0 -1.0 2020-01-02 -0.5 0.5 0.0 2020-01-03 0.0 0.0 1.0 2020-01-04 0.5 -0.5 0.0 2020-01-05 1.0 -1.0 -1.0 >>> fitted_scaler = MinMaxScaler((-1, 1)).fit(np.array([[2], [4]])) >>> df.vbt.transform(fitted_scaler) a b c 2020-01-01 -2.0 2.0 -2.0 2020-01-02 -1.0 1.0 -1.0 2020-01-03 0.0 0.0 0.0 2020-01-04 1.0 -1.0 -1.0 2020-01-05 2.0 -2.0 -2.0 ``` """ is_fitted = True try: check_is_fitted(transformer) except NotFittedError: is_fitted = False if not is_fitted: result = transformer.fit_transform(self.to_2d_array(), **kwargs) else: result = transformer.transform(self.to_2d_array(), **kwargs) return self.wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs)) def zscore(self, **kwargs) -> tp.SeriesFrame: """Compute z-score using `sklearn.preprocessing.StandardScaler`.""" return self.scale(with_mean=True, with_std=True, **kwargs) def rebase( self, base: float, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Rebase all series to the given base. This makes comparing/plotting different series together easier. 
Will forward and backward fill NaN values.""" func = jit_reg.resolve_option(nb.fbfill_nb, jitted) func = ch_reg.resolve_option(func, chunked) result = func(self.to_2d_array()) result = result / result[0] * base return self.wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs)) # ############# Conversion ############# # def drawdown( self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get drawdown series.""" func = jit_reg.resolve_option(nb.drawdown_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(self.to_2d_array()) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) @property def ranges(self) -> Ranges: """`GenericAccessor.get_ranges` with default arguments.""" return self.get_ranges() def get_ranges(self, *args, wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> Ranges: """Generate range records. See `vectorbtpro.generic.ranges.Ranges.from_array`.""" wrapper_kwargs = merge_dicts(self.wrapper.config, wrapper_kwargs) return Ranges.from_array(self.obj, *args, wrapper_kwargs=wrapper_kwargs, **kwargs) @property def drawdowns(self) -> Drawdowns: """`GenericAccessor.get_drawdowns` with default arguments.""" return self.get_drawdowns() def get_drawdowns(self, *args, **kwargs) -> Drawdowns: """Generate drawdown records. 
See `vectorbtpro.generic.drawdowns.Drawdowns.from_price`.""" return Drawdowns.from_price(self.obj, *args, wrapper=self.wrapper, **kwargs) def to_mapped( self, dropna: bool = True, dtype: tp.Optional[tp.DTypeLike] = None, group_by: tp.GroupByLike = None, **kwargs, ) -> MappedArray: """Convert this object into an instance of `vectorbtpro.records.mapped_array.MappedArray`.""" mapped_arr = self.to_2d_array().flatten(order="F") col_arr = np.repeat(np.arange(self.wrapper.shape_2d[1]), self.wrapper.shape_2d[0]) idx_arr = np.tile(np.arange(self.wrapper.shape_2d[0]), self.wrapper.shape_2d[1]) if dropna and np.isnan(mapped_arr).any(): not_nan_mask = ~np.isnan(mapped_arr) mapped_arr = mapped_arr[not_nan_mask] col_arr = col_arr[not_nan_mask] idx_arr = idx_arr[not_nan_mask] return MappedArray( self.wrapper, np.asarray(mapped_arr, dtype=dtype), col_arr, idx_arr=idx_arr, **kwargs, ).regroup(group_by) def to_returns(self, **kwargs) -> tp.SeriesFrame: """Get returns of this object.""" from vectorbtpro.returns.accessors import ReturnsAccessor return ReturnsAccessor.from_value( self._obj, wrapper=self.wrapper, return_values=True, **kwargs, ) def to_log_returns(self, **kwargs) -> tp.SeriesFrame: """Get log returns of this object.""" from vectorbtpro.returns.accessors import ReturnsAccessor return ReturnsAccessor.from_value( self._obj, wrapper=self.wrapper, return_values=True, log_returns=True, **kwargs, ) def to_daily_returns(self, **kwargs) -> tp.SeriesFrame: """Get daily returns of this object.""" from vectorbtpro.returns.accessors import ReturnsAccessor return ReturnsAccessor.from_value( self._obj, wrapper=self.wrapper, return_values=False, **kwargs, ).daily() def to_daily_log_returns(self, **kwargs) -> tp.SeriesFrame: """Get daily log returns of this object.""" from vectorbtpro.returns.accessors import ReturnsAccessor return ReturnsAccessor.from_value( self._obj, wrapper=self.wrapper, return_values=False, log_returns=True, **kwargs, ).daily() # ############# Patterns ############# 
# def find_pattern(self, *args, **kwargs) -> PatternRanges: """Generate pattern range records. See `vectorbtpro.generic.ranges.PatternRanges.from_pattern_search`.""" return PatternRanges.from_pattern_search(self.obj, *args, **kwargs) # ############# Crossover ############# # def crossed_above( self, other: tp.ArrayLike, wait: int = 0, dropna: bool = False, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.generic.nb.base.crossed_above_nb`. Usage: ```pycon >>> df['b'].vbt.crossed_above(df['c']) 2020-01-01 False 2020-01-02 False 2020-01-03 False 2020-01-04 False 2020-01-05 False dtype: bool >>> df['a'].vbt.crossed_above(df['b']) 2020-01-01 False 2020-01-02 False 2020-01-03 False 2020-01-04 True 2020-01-05 False dtype: bool >>> df['a'].vbt.crossed_above(df['b'], wait=1) 2020-01-01 False 2020-01-02 False 2020-01-03 False 2020-01-04 False 2020-01-05 True dtype: bool ``` """ broadcastable_args = dict(obj=self.obj, other=other) broadcast_kwargs = merge_dicts(dict(keep_flex=dict(obj=False, other=True)), broadcast_kwargs) broadcasted_args, wrapper = reshaping.broadcast( broadcastable_args, to_pd=False, min_ndim=2, return_wrapper=True, **broadcast_kwargs, ) func = jit_reg.resolve_option(nb.crossed_above_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( broadcasted_args["obj"], broadcasted_args["other"], wait=wait, dropna=dropna, ) return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def crossed_below( self, other: tp.ArrayLike, wait: int = 0, dropna: bool = True, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.generic.nb.base.crossed_below_nb`. 
Also, see `GenericAccessor.crossed_above` for similar examples.""" broadcastable_args = dict(obj=self.obj, other=other) broadcast_kwargs = merge_dicts(dict(keep_flex=dict(obj=False, other=True)), broadcast_kwargs) broadcasted_args, wrapper = reshaping.broadcast( broadcastable_args, to_pd=False, min_ndim=2, return_wrapper=True, **broadcast_kwargs, ) func = jit_reg.resolve_option(nb.crossed_below_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( broadcasted_args["obj"], broadcasted_args["other"], wait=wait, dropna=dropna, ) return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) # ############# Resolution ############# # def resolve_self( self: GenericAccessorT, cond_kwargs: tp.KwargsLike = None, custom_arg_names: tp.Optional[tp.Set[str]] = None, impacts_caching: bool = True, silence_warnings: bool = False, ) -> GenericAccessorT: """Resolve self. See `vectorbtpro.base.wrapping.Wrapping.resolve_self`. Creates a copy of this instance `mapping` is different in `cond_kwargs`.""" if cond_kwargs is None: cond_kwargs = {} if custom_arg_names is None: custom_arg_names = set() reself = Wrapping.resolve_self( self, cond_kwargs=cond_kwargs, custom_arg_names=custom_arg_names, impacts_caching=impacts_caching, silence_warnings=silence_warnings, ) if "mapping" in cond_kwargs: self_copy = reself.replace(mapping=cond_kwargs["mapping"]) if not checks.is_deep_equal(self_copy.mapping, reself.mapping): if not silence_warnings: warn( f"Changing the mapping will create a copy of this object. " f"Consider setting it upon object creation to re-use existing cache." ) for alias in reself.self_aliases: if alias not in custom_arg_names: cond_kwargs[alias] = self_copy cond_kwargs["mapping"] = self_copy.mapping if impacts_caching: cond_kwargs["use_caching"] = False return self_copy return reself # ############# Stats ############# # @property def stats_defaults(self) -> tp.Kwargs: """Defaults for `GenericAccessor.stats`. 
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and `stats` from `vectorbtpro._settings.generic`.""" from vectorbtpro._settings import settings generic_stats_cfg = settings["generic"]["stats"] return merge_dicts(Analyzable.stats_defaults.__get__(self), generic_stats_cfg) _metrics: tp.ClassVar[Config] = HybridConfig( dict( start_index=dict( title="Start Index", calc_func=lambda self: self.wrapper.index[0], agg_func=None, tags="wrapper", ), end_index=dict( title="End Index", calc_func=lambda self: self.wrapper.index[-1], agg_func=None, tags="wrapper", ), total_duration=dict( title="Total Duration", calc_func=lambda self: len(self.wrapper.index), apply_to_timedelta=True, agg_func=None, tags="wrapper", ), count=dict(title="Count", calc_func="count", inv_check_has_mapping=True, tags=["generic", "describe"]), mean=dict(title="Mean", calc_func="mean", inv_check_has_mapping=True, tags=["generic", "describe"]), std=dict(title="Std", calc_func="std", inv_check_has_mapping=True, tags=["generic", "describe"]), min=dict(title="Min", calc_func="min", inv_check_has_mapping=True, tags=["generic", "describe"]), median=dict(title="Median", calc_func="median", inv_check_has_mapping=True, tags=["generic", "describe"]), max=dict(title="Max", calc_func="max", inv_check_has_mapping=True, tags=["generic", "describe"]), idx_min=dict( title="Min Index", calc_func="idxmin", agg_func=None, inv_check_has_mapping=True, tags=["generic", "index"], ), idx_max=dict( title="Max Index", calc_func="idxmax", agg_func=None, inv_check_has_mapping=True, tags=["generic", "index"], ), value_counts=dict( title="Value Counts", calc_func=lambda value_counts: reshaping.to_dict(value_counts, orient="index_series"), resolve_value_counts=True, check_has_mapping=True, tags=["generic", "value_counts"], ), ) ) @property def metrics(self) -> Config: return self._metrics # ############# Plotting ############# # def plot( self, column: tp.Optional[tp.Label] = None, trace_names: tp.TraceNames 
= None, x_labels: tp.Optional[tp.Labels] = None, return_fig: bool = True, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Create `vectorbtpro.generic.plotting.Scatter` and return the figure. Usage: ```pycon >>> df.vbt.plot().show() ``` ![](/assets/images/api/df_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_plot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.generic.plotting import Scatter if column is not None: _self = self.select_col(column=column) else: _self = self if x_labels is None: x_labels = _self.wrapper.index if trace_names is None: if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None): trace_names = _self.wrapper.columns scatter = Scatter(data=_self.to_2d_array(), trace_names=trace_names, x_labels=x_labels, **kwargs) if return_fig: return scatter.fig return scatter def lineplot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """`GenericAccessor.plot` with 'lines' mode. Usage: ```pycon >>> df.vbt.lineplot().show() ``` ![](/assets/images/api/df_lineplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_lineplot.dark.svg#only-dark){: .iimg loading=lazy } """ return self.plot(column=column, **merge_dicts(dict(trace_kwargs=dict(mode="lines")), kwargs)) def scatterplot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """`GenericAccessor.plot` with 'markers' mode. 
Usage: ```pycon >>> df.vbt.scatterplot().show() ``` ![](/assets/images/api/df_scatterplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_scatterplot.dark.svg#only-dark){: .iimg loading=lazy } """ return self.plot(column=column, **merge_dicts(dict(trace_kwargs=dict(mode="markers")), kwargs)) def barplot( self, column: tp.Optional[tp.Label] = None, trace_names: tp.TraceNames = None, x_labels: tp.Optional[tp.Labels] = None, return_fig: bool = True, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Create `vectorbtpro.generic.plotting.Bar` and return the figure. Usage: ```pycon >>> df.vbt.barplot().show() ``` ![](/assets/images/api/df_barplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_barplot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.generic.plotting import Bar if column is not None: _self = self.select_col(column=column) else: _self = self if x_labels is None: x_labels = _self.wrapper.index if trace_names is None: if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None): trace_names = _self.wrapper.columns bar = Bar(data=_self.to_2d_array(), trace_names=trace_names, x_labels=x_labels, **kwargs) if return_fig: return bar.fig return bar def histplot( self, column: tp.Optional[tp.Label] = None, by_level: tp.Optional[tp.Level] = None, trace_names: tp.TraceNames = None, group_by: tp.GroupByLike = None, return_fig: bool = True, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Create `vectorbtpro.generic.plotting.Histogram` and return the figure. 
Usage: ```pycon >>> df.vbt.histplot().show() ``` ![](/assets/images/api/df_histplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_histplot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.generic.plotting import Histogram if by_level is not None: return self.obj.unstack(by_level).vbt.histplot( column=column, trace_names=trace_names, group_by=group_by, return_fig=return_fig, **kwargs, ) if column is not None: _self = self.select_col(column=column) else: _self = self if _self.wrapper.grouper.is_grouped(group_by=group_by): return _self.flatten_grouped(group_by=group_by).vbt.histplot(trace_names=trace_names, **kwargs) if trace_names is None: if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None): trace_names = _self.wrapper.columns hist = Histogram(data=_self.to_2d_array(), trace_names=trace_names, **kwargs) if return_fig: return hist.fig return hist def boxplot( self, column: tp.Optional[tp.Label] = None, by_level: tp.Optional[tp.Level] = None, trace_names: tp.TraceNames = None, group_by: tp.GroupByLike = None, return_fig: bool = True, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Create `vectorbtpro.generic.plotting.Box` and return the figure. 
Usage: ```pycon >>> df.vbt.boxplot().show() ``` ![](/assets/images/api/df_boxplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_boxplot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.generic.plotting import Box if by_level is not None: return self.obj.unstack(by_level).vbt.boxplot( column=column, trace_names=trace_names, group_by=group_by, return_fig=return_fig, **kwargs, ) if column is not None: _self = self.select_col(column=column) else: _self = self if _self.wrapper.grouper.is_grouped(group_by=group_by): return _self.flatten_grouped(group_by=group_by).vbt.boxplot(trace_names=trace_names, **kwargs) if trace_names is None: if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None): trace_names = _self.wrapper.columns box = Box(data=_self.to_2d_array(), trace_names=trace_names, **kwargs) if return_fig: return box.fig return box def plot_against( self, other: tp.ArrayLike, column: tp.Optional[tp.Label] = None, trace_kwargs: tp.KwargsLike = None, other_trace_kwargs: tp.Union[str, tp.KwargsLike] = None, pos_trace_kwargs: tp.KwargsLike = None, neg_trace_kwargs: tp.KwargsLike = None, hidden_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot Series as a line against another line. Args: other (array_like): Second array. Will broadcast. column (hashable): Column to plot. trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`. other_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `other`. Set to 'hidden' to hide. pos_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for positive line. neg_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for negative line. hidden_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for hidden lines. 
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> df['a'].vbt.plot_against(df['b']).show() ``` ![](/assets/images/api/sr_plot_against.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/sr_plot_against.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.figure import make_figure if trace_kwargs is None: trace_kwargs = {} if other_trace_kwargs is None: other_trace_kwargs = {} if pos_trace_kwargs is None: pos_trace_kwargs = {} if neg_trace_kwargs is None: neg_trace_kwargs = {} if hidden_trace_kwargs is None: hidden_trace_kwargs = {} obj = self.obj if isinstance(obj, pd.DataFrame): obj = self.select_col_from_obj(obj, column=column) if other is None: other = pd.Series.vbt.empty_like(obj, 1) else: other = reshaping.to_pd_array(other) if isinstance(other, pd.DataFrame): other = self.select_col_from_obj(other, column=column) obj, other = reshaping.broadcast(obj, other, columns_from="keep") if other.name is None: other = other.rename("Other") if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) # TODO: Using masks feels hacky pos_mask = obj > other if pos_mask.any(): # Fill positive area pos_obj = obj.copy() pos_obj[~pos_mask] = other[~pos_mask] other.vbt.lineplot( trace_kwargs=merge_dicts( dict( line=dict(color="rgba(0, 0, 0, 0)", width=0), opacity=0, hoverinfo="skip", showlegend=False, name=None, ), hidden_trace_kwargs, ), add_trace_kwargs=add_trace_kwargs, fig=fig, ) pos_obj.vbt.lineplot( trace_kwargs=merge_dicts( dict( fillcolor="rgba(0, 128, 0, 0.25)", line=dict(color="rgba(0, 0, 0, 0)", width=0), opacity=0, fill="tonexty", connectgaps=False, hoverinfo="skip", showlegend=False, name=None, ), pos_trace_kwargs, ), add_trace_kwargs=add_trace_kwargs, fig=fig, ) neg_mask = obj < other if neg_mask.any(): # Fill negative area neg_obj = obj.copy() neg_obj[~neg_mask] = other[~neg_mask] 
other.vbt.lineplot( trace_kwargs=merge_dicts( dict( line=dict(color="rgba(0, 0, 0, 0)", width=0), opacity=0, hoverinfo="skip", showlegend=False, name=None, ), hidden_trace_kwargs, ), add_trace_kwargs=add_trace_kwargs, fig=fig, ) neg_obj.vbt.lineplot( trace_kwargs=merge_dicts( dict( line=dict(color="rgba(0, 0, 0, 0)", width=0), fillcolor="rgba(255, 0, 0, 0.25)", opacity=0, fill="tonexty", connectgaps=False, hoverinfo="skip", showlegend=False, name=None, ), neg_trace_kwargs, ), add_trace_kwargs=add_trace_kwargs, fig=fig, ) # Plot main traces obj.vbt.lineplot(trace_kwargs=trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig) if other_trace_kwargs == "hidden": other_trace_kwargs = dict( line=dict(color="rgba(0, 0, 0, 0)", width=0), opacity=0.0, hoverinfo="skip", showlegend=False, name=None, ) other.vbt.lineplot(trace_kwargs=other_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig) return fig def overlay_with_heatmap( self, other: tp.ArrayLike, column: tp.Optional[tp.Label] = None, trace_kwargs: tp.KwargsLike = None, heatmap_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot Series as a line and overlays it with a heatmap. Args: other (array_like): Second array. Will broadcast. column (hashable): Column to plot. trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`. heatmap_kwargs (dict): Keyword arguments passed to `GenericDFAccessor.heatmap`. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. 
Usage: ```pycon >>> df['a'].vbt.overlay_with_heatmap(df['b']).show() ``` ![](/assets/images/api/sr_overlay_with_heatmap.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/sr_overlay_with_heatmap.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.figure import make_subplots from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if trace_kwargs is None: trace_kwargs = {} if heatmap_kwargs is None: heatmap_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} obj = self.obj if isinstance(obj, pd.DataFrame): obj = self.select_col_from_obj(obj, column=column) if other is None: other = pd.Series.vbt.empty_like(obj, 1) else: other = reshaping.to_pd_array(other) if isinstance(other, pd.DataFrame): other = self.select_col_from_obj(other, column=column) obj, other = reshaping.broadcast(obj, other, columns_from="keep") if other.name is None: other = other.rename("Other") if fig is None: fig = make_subplots(specs=[[{"secondary_y": True}]]) if "width" in plotting_cfg["layout"]: fig.update_layout(width=plotting_cfg["layout"]["width"] + 100) fig.update_layout(**layout_kwargs) other.vbt.ts_heatmap(**heatmap_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig) obj.vbt.lineplot( trace_kwargs=merge_dicts(dict(line=dict(color=plotting_cfg["color_schema"]["blue"])), trace_kwargs), add_trace_kwargs=merge_dicts(dict(secondary_y=True), add_trace_kwargs), fig=fig, ) return fig

    def heatmap(
        self,
        column: tp.Optional[tp.Label] = None,
        x_level: tp.Optional[tp.Level] = None,
        y_level: tp.Optional[tp.Level] = None,
        symmetric: bool = False,
        sort: bool = True,
        x_labels: tp.Optional[tp.Labels] = None,
        y_labels: tp.Optional[tp.Labels] = None,
        slider_level: tp.Optional[tp.Level] = None,
        active: int = 0,
        slider_labels: tp.Optional[tp.Labels] = None,
        return_fig: bool = True,
        fig: tp.Optional[tp.BaseFigure] = None,
        **kwargs,
    ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
        """Create a heatmap figure based on object's multi-index and values.

        If the object is two-dimensional or the index is not a multi-index, returns a regular heatmap.

        If multi-index contains more than two levels or you want them in specific order,
        pass `x_level` and `y_level`, each (`int` if index or `str` if name) corresponding
        to an axis of the heatmap. Optionally, pass `slider_level` to use a level as a slider.

        Creates `vectorbtpro.generic.plotting.Heatmap` and returns the figure.

        Args:
            column (hashable): Column to select before plotting. No selection if None.
            x_level (int or str): Index level for the x-axis. Resolved by `indexes.pick_levels`.
            y_level (int or str): Index level for the y-axis. Resolved by `indexes.pick_levels`.
            symmetric (bool): Passed to `unstack_to_df`.
            sort (bool): Passed to `unstack_to_df`.
            x_labels (any): Labels of the x-axis. Default to the columns of the plotted frame.
            y_labels (any): Labels of the y-axis. Default to the index of the plotted frame.
            slider_level (int or str): Optional index level to group by; one trace is added per group.
            active (int): Position of the trace that is initially visible, and the active slider step.
            slider_labels (any): Labels that replace the group names, taken by group position.
            return_fig (bool): Whether to return the figure rather than the `Heatmap` instance.
            fig (Figure or FigureWidget): Figure to add traces to.
            **kwargs: Keyword arguments passed to `vectorbtpro.generic.plotting.Heatmap`.

        Raises:
            ValueError: If `return_fig` is False while a slider level is used.

        Usage:
            * Plotting a figure based on a regular index:

            ```pycon
            >>> df = pd.DataFrame([
            ...     [0, np.nan, np.nan],
            ...     [np.nan, 1, np.nan],
            ...     [np.nan, np.nan, 2]
            ... ])
            >>> df.vbt.heatmap().show()
            ```

            ![](/assets/images/api/df_heatmap.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/df_heatmap.dark.svg#only-dark){: .iimg loading=lazy }

            * Plotting a figure based on a multi-index:

            ```pycon
            >>> multi_index = pd.MultiIndex.from_tuples([
            ...     (1, 1),
            ...     (2, 2),
            ...     (3, 3)
            ... ])
            >>> sr = pd.Series(np.arange(len(multi_index)), index=multi_index)
            >>> sr
            1  1    0
            2  2    1
            3  3    2
            dtype: int64

            >>> sr.vbt.heatmap().show()
            ```

            ![](/assets/images/api/sr_heatmap.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/sr_heatmap.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.generic.plotting import Heatmap

        if column is not None:
            _self = self.select_col(column=column)
        else:
            _self = self
        # Regular heatmap: the data is already two-dimensional or has no levels to unstack
        if _self.ndim == 2 or not isinstance(self.wrapper.index, pd.MultiIndex):
            if x_labels is None:
                x_labels = _self.wrapper.columns
            if y_labels is None:
                y_labels = _self.wrapper.index
            heatmap = Heatmap(
                data=_self.to_2d_array(),
                x_labels=x_labels,
                y_labels=y_labels,
                fig=fig,
                **kwargs,
            )
            if return_fig:
                return heatmap.fig
            return heatmap

        (x_level, y_level), (slider_level,) = indexes.pick_levels(
            _self.wrapper.index,
            required_levels=(x_level, y_level),
            optional_levels=(slider_level,),
        )

        x_level_vals = _self.wrapper.index.get_level_values(x_level)
        y_level_vals = _self.wrapper.index.get_level_values(y_level)
        # Fall back to generic axis names in the hover text if a level is unnamed
        x_name = x_level_vals.name if x_level_vals.name is not None else "x"
        y_name = y_level_vals.name if y_level_vals.name is not None else "y"
        # NOTE(review): the hovertemplate literal below contains raw line breaks exactly as it
        # appears in this view - presumably a line-break marker was lost; confirm against the
        # repository before editing it.
        kwargs = merge_dicts(
            dict(
                trace_kwargs=dict(
                    hovertemplate=f"{x_name}: %{{x}}
" + f"{y_name}: %{{y}}
" + "value: %{z}"
                ),
                xaxis_title=x_level_vals.name,
                yaxis_title=y_level_vals.name,
            ),
            kwargs,
        )

        if slider_level is None:
            # No grouping
            df = _self.unstack_to_df(
                index_levels=y_level,
                column_levels=x_level,
                symmetric=symmetric,
                sort=sort,
            )
            return df.vbt.heatmap(
                x_labels=x_labels,
                y_labels=y_labels,
                fig=fig,
                return_fig=return_fig,
                **kwargs,
            )

        # Requires grouping
        # See https://plotly.com/python/sliders/
        if not return_fig:
            raise ValueError("Cannot use return_fig=False and slider_level simultaneously")
        _slider_labels = []
        for i, (name, group) in enumerate(_self.obj.groupby(level=slider_level)):
            if slider_labels is not None:
                name = slider_labels[i]
            _slider_labels.append(name)
            df = group.vbt.unstack_to_df(
                index_levels=y_level,
                column_levels=x_level,
                symmetric=symmetric,
                sort=sort,
            )
            # Labels are taken from the first group and then re-used for all following groups
            if x_labels is None:
                x_labels = df.columns
            if y_labels is None:
                y_labels = df.index
            _kwargs = merge_dicts(
                dict(
                    trace_kwargs=dict(name=str(name) if name is not None else None, visible=False),
                ),
                kwargs,
            )
            # True only while no figure exists yet, so the height is extended at most once
            default_size = fig is None and "height" not in _kwargs
            fig = Heatmap(
                data=reshaping.to_2d_array(df),
                x_labels=x_labels,
                y_labels=y_labels,
                fig=fig,
                **_kwargs,
            ).fig
            if default_size:
                fig.layout["height"] += 100  # slider takes up space
        fig.data[active].visible = True
        # One slider step per trace; each step shows only its own trace
        steps = []
        for i in range(len(fig.data)):
            step = dict(
                method="update",
                args=[{"visible": [False] * len(fig.data)}, {}],
                label=str(_slider_labels[i]) if _slider_labels[i] is not None else None,
            )
            step["args"][0]["visible"][i] = True
            steps.append(step)
        prefix = (
            f"{_self.wrapper.index.names[slider_level]}: "
            if _self.wrapper.index.names[slider_level] is not None
            else None
        )
        sliders = [
            dict(
                active=active,
                currentvalue={"prefix": prefix},
                pad={"t": 50},
                steps=steps,
            )
        ]
        fig.update_layout(sliders=sliders)
        return fig

    def ts_heatmap( self, column: tp.Optional[tp.Label] = None, is_y_category: bool = True, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Heatmap of time-series data.""" if column is not None: obj =
self.select_col_from_obj(self.obj, column=column) else: obj = self.obj if isinstance(obj, pd.Series): obj = obj.to_frame() return obj.transpose().iloc[::-1].vbt.heatmap(is_y_category=is_y_category, **kwargs)

    def volume(
        self,
        column: tp.Optional[tp.Label] = None,
        x_level: tp.Optional[tp.Level] = None,
        y_level: tp.Optional[tp.Level] = None,
        z_level: tp.Optional[tp.Level] = None,
        x_labels: tp.Optional[tp.Labels] = None,
        y_labels: tp.Optional[tp.Labels] = None,
        z_labels: tp.Optional[tp.Labels] = None,
        slider_level: tp.Optional[tp.Level] = None,
        slider_labels: tp.Optional[tp.Labels] = None,
        active: int = 0,
        scene_name: str = "scene",
        fillna: tp.Optional[tp.Number] = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        return_fig: bool = True,
        **kwargs,
    ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
        """Create a 3D volume figure based on object's multi-index and values.

        If multi-index contains more than three levels or you want them in specific order, pass
        `x_level`, `y_level`, and `z_level`, each (`int` if index or `str` if name) corresponding
        to an axis of the volume. Optionally, pass `slider_level` to use a level as a slider.

        Creates `vectorbtpro.generic.plotting.Volume` and returns the figure.

        Args:
            column (hashable): Column to select before plotting.
            x_level (int or str): Index level for the x-axis. Resolved by `indexes.pick_levels`.
            y_level (int or str): Index level for the y-axis. Resolved by `indexes.pick_levels`.
            z_level (int or str): Index level for the z-axis. Resolved by `indexes.pick_levels`.
            x_labels (any): Labels of the x-axis. Default to the unique values of the x level.
            y_labels (any): Labels of the y-axis. Default to the unique values of the y level.
            z_labels (any): Labels of the z-axis. Default to the unique values of the z level.
            slider_level (int or str): Optional index level to group by; one trace is added per group.
            slider_labels (any): Labels that replace the group names, taken by group position.
            active (int): Position of the trace that is initially visible, and the active slider step.
            scene_name (str): Key under which the axis titles are passed, also passed as `scene_name`.
            fillna (number): Value to replace NaN with (`np.nan_to_num`). No replacement if None.
            fig (Figure or FigureWidget): Figure to add traces to.
            return_fig (bool): Whether to return the figure rather than the `Volume` instance.
            **kwargs: Keyword arguments passed to `vectorbtpro.generic.plotting.Volume`.

        Raises:
            ValueError: If `return_fig` is False while a slider level is used.

        Usage:
            ```pycon
            >>> multi_index = pd.MultiIndex.from_tuples([
            ...     (1, 1, 1),
            ...     (2, 2, 2),
            ...     (3, 3, 3)
            ... ])
            >>> sr = pd.Series(np.arange(len(multi_index)), index=multi_index)
            >>> sr
            1  1  1    0
            2  2  2    1
            3  3  3    2
            dtype: int64

            >>> sr.vbt.volume().show()
            ```

            ![](/assets/images/api/sr_volume.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/sr_volume.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.generic.plotting import Volume

        self_col = self.select_col(column=column)
        (x_level, y_level, z_level), (slider_level,) = indexes.pick_levels(
            self_col.wrapper.index,
            required_levels=(x_level, y_level, z_level),
            optional_levels=(slider_level,),
        )

        x_level_vals = self_col.wrapper.index.get_level_values(x_level)
        y_level_vals = self_col.wrapper.index.get_level_values(y_level)
        z_level_vals = self_col.wrapper.index.get_level_values(z_level)
        # Labels are just unique level values
        if x_labels is None:
            x_labels = np.unique(x_level_vals)
        if y_labels is None:
            y_labels = np.unique(y_level_vals)
        if z_labels is None:
            z_labels = np.unique(z_level_vals)

        # Fall back to generic axis names in the hover text if a level is unnamed
        x_name = x_level_vals.name if x_level_vals.name is not None else "x"
        y_name = y_level_vals.name if y_level_vals.name is not None else "y"
        z_name = z_level_vals.name if z_level_vals.name is not None else "z"
        def_kwargs = dict()
        # NOTE(review): the hovertemplate literal below contains raw line breaks exactly as it
        # appears in this view - presumably a line-break marker was lost; confirm against the
        # repository before editing it.
        def_kwargs["trace_kwargs"] = dict(
            hovertemplate=f"{x_name}: %{{x}}
" + f"{y_name}: %{{y}}
" + f"{z_name}: %{{z}}
" + "value: %{value}"
        )
        def_kwargs[scene_name] = dict(
            xaxis_title=x_level_vals.name,
            yaxis_title=y_level_vals.name,
            zaxis_title=z_level_vals.name,
        )
        def_kwargs["scene_name"] = scene_name
        kwargs = merge_dicts(def_kwargs, kwargs)

        # Set if any plotted array still contains NaN after the optional fill
        contains_nan = False
        if slider_level is None:
            # No grouping
            v = self_col.unstack_to_array(levels=(x_level, y_level, z_level))
            if fillna is not None:
                v = np.nan_to_num(v, nan=fillna)
            if np.isnan(v).any():
                contains_nan = True
            volume = Volume(data=v, x_labels=x_labels, y_labels=y_labels, z_labels=z_labels, fig=fig, **kwargs)
            if return_fig:
                fig = volume.fig
            else:
                fig = volume
        else:
            # Requires grouping
            # See https://plotly.com/python/sliders/
            if not return_fig:
                raise ValueError("Cannot use return_fig=False and slider_level simultaneously")
            _slider_labels = []
            for i, (name, group) in enumerate(self_col.obj.groupby(level=slider_level)):
                if slider_labels is not None:
                    name = slider_labels[i]
                _slider_labels.append(name)
                v = group.vbt.unstack_to_array(levels=(x_level, y_level, z_level))
                if fillna is not None:
                    v = np.nan_to_num(v, nan=fillna)
                if np.isnan(v).any():
                    contains_nan = True
                _kwargs = merge_dicts(
                    dict(trace_kwargs=dict(name=str(name) if name is not None else None, visible=False)),
                    kwargs,
                )
                # True only while no figure exists yet, so the height is extended at most once
                default_size = fig is None and "height" not in _kwargs
                fig = Volume(data=v, x_labels=x_labels, y_labels=y_labels, z_labels=z_labels, fig=fig, **_kwargs).fig
                if default_size:
                    fig.layout["height"] += 100  # slider takes up space
            fig.data[active].visible = True
            # One slider step per trace; each step shows only its own trace
            steps = []
            for i in range(len(fig.data)):
                step = dict(
                    method="update",
                    args=[{"visible": [False] * len(fig.data)}, {}],
                    label=str(_slider_labels[i]) if _slider_labels[i] is not None else None,
                )
                step["args"][0]["visible"][i] = True
                steps.append(step)
            prefix = (
                f"{self_col.wrapper.index.names[slider_level]}: "
                if self_col.wrapper.index.names[slider_level] is not None
                else None
            )
            sliders = [dict(active=active, currentvalue={"prefix": prefix}, pad={"t": 50}, steps=steps)]
            fig.update_layout(sliders=sliders)

        if contains_nan:
            warn("Data contains NaNs. Use `fillna` argument or `show` method in case of visualization issues.")
        return fig

    def qqplot(
        self,
        column: tp.Optional[tp.Label] = None,
        sparams: tp.Union[tp.Iterable, tuple, None] = (),
        dist: str = "norm",
        plot_line: bool = True,
        line_shape_kwargs: tp.KwargsLike = None,
        xref: str = "x",
        yref: str = "y",
        fig: tp.Optional[tp.BaseFigure] = None,
        **kwargs,
    ) -> tp.BaseFigure:
        """Plot probability plot using `scipy.stats.probplot`.

        `**kwargs` are passed to `GenericAccessor.scatterplot`.

        Args:
            column (hashable): Column to plot.
            sparams (iterable): Passed to `scipy.stats.probplot` as `sparams`.
            dist (str): Passed to `scipy.stats.probplot` as `dist`.
            plot_line (bool): Whether to add a red line shape spanning the first and the
                last x-value of the plot.
            line_shape_kwargs (dict): Keyword arguments that override the defaults of the line shape.
            xref (str): X-axis reference of the line shape.
            yref (str): Y-axis reference of the line shape.
            fig (Figure or FigureWidget): Figure to add traces to.

        Usage:
            ```pycon
            >>> pd.Series(np.random.standard_normal(100)).vbt.qqplot().show()
            ```

            ![](/assets/images/api/sr_qqplot.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/sr_qqplot.dark.svg#only-dark){: .iimg loading=lazy }
        """
        import scipy.stats as st

        obj = self.select_col_from_obj(self.obj, column=column)
        qq = st.probplot(obj, sparams=sparams, dist=dist)
        # qq[0][0] is used as the index (x) and qq[0][1] as the values (y)
        fig = pd.Series(qq[0][1], index=qq[0][0]).vbt.scatterplot(fig=fig, **kwargs)

        if plot_line:
            if line_shape_kwargs is None:
                line_shape_kwargs = {}
            x = np.array([qq[0][0][0], qq[0][0][-1]])
            # presumably qq[1] holds (slope, intercept, ...) of the fitted line - verify
            # against the scipy.stats.probplot documentation
            y = qq[1][1] + qq[1][0] * x
            fig.add_shape(
                **merge_dicts(
                    dict(type="line", xref=xref, yref=yref, x0=x[0], y0=y[0], x1=x[1], y1=y[1], line=dict(color="red")),
                    line_shape_kwargs,
                )
            )

        return fig

    def areaplot( self, line_shape: str = "spline", line_visible: bool = False, colorway: tp.Union[None, str, tp.Sequence[str]] = None, trace_kwargs: tp.KwargsLikeSequence = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot stacked area. Args: line_shape (str): Line shape. line_visible (bool): Whether to make line visible. colorway (str or sequence): Name of the built-in, qualitative colorway, or a list with colors. trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> df.vbt.areaplot().show() ``` ![](/assets/images/api/df_areaplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_areaplot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") from vectorbtpro.utils.figure import make_figure import plotly.express as px if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) if colorway is None: if fig.layout.colorway is not None: colorway = fig.layout.colorway else: colorway = fig.layout.template.layout.colorway if len(self.wrapper.columns) > len(colorway): colorway = px.colors.qualitative.Alphabet elif isinstance(colorway, str): colorway = getattr(px.colors.qualitative, colorway) pos_mask = self.obj.values > 0 pos_mask_any = pos_mask.any() neg_mask = self.obj.values < 0 neg_mask_any = neg_mask.any() pos_showlegend = False neg_showlegend = False if pos_mask_any: pos_showlegend = True elif neg_mask_any: neg_showlegend = True line_width = None if line_visible else 0 line_opacity = 0.3 if line_visible else 0.8 if pos_mask_any: pos_df = self.obj.copy() pos_df[neg_mask] = 0.0 fig = pos_df.vbt.lineplot( trace_kwargs=[ merge_dicts( dict( legendgroup="area_" + str(c), stackgroup="one", line=dict(width=line_width, color=colorway[c % len(colorway)], shape=line_shape), fillcolor=adjust_opacity(colorway[c % len(colorway)], line_opacity), showlegend=pos_showlegend, ), resolve_dict(trace_kwargs, i=c), ) for c in range(len(self.wrapper.columns)) ], add_trace_kwargs=add_trace_kwargs, use_gl=False, fig=fig, **layout_kwargs, ) if neg_mask_any: neg_df = self.obj.copy() neg_df[pos_mask] = 0.0 fig = neg_df.vbt.lineplot( trace_kwargs=[ merge_dicts( dict( legendgroup="area_" + str(c), stackgroup="two", line=dict(width=line_width, color=colorway[c % len(colorway)], shape=line_shape), fillcolor=adjust_opacity(colorway[c % 
len(colorway)], line_opacity), showlegend=neg_showlegend, ), resolve_dict(trace_kwargs, i=c), ) for c in range(len(self.wrapper.columns)) ], add_trace_kwargs=add_trace_kwargs, use_gl=False, fig=fig, **layout_kwargs, ) return fig def plot_pattern( self, pattern: tp.ArrayLike, interp_mode: tp.Union[int, str] = "mixed", rescale_mode: tp.Union[int, str] = "minmax", vmin: float = np.nan, vmax: float = np.nan, pmin: float = np.nan, pmax: float = np.nan, invert: bool = False, error_type: tp.Union[int, str] = "absolute", max_error: tp.ArrayLike = np.nan, max_error_interp_mode: tp.Union[None, int, str] = None, column: tp.Optional[tp.Label] = None, plot_obj: bool = True, fill_distance: bool = False, obj_trace_kwargs: tp.KwargsLike = None, pattern_trace_kwargs: tp.KwargsLike = None, lower_max_error_trace_kwargs: tp.KwargsLike = None, upper_max_error_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot pattern. Mimics the same similarity calculation procedure as implemented in `vectorbtpro.generic.nb.patterns.pattern_similarity_nb`. 
Usage: ```pycon >>> sr = pd.Series([10, 11, 12, 13, 12, 13, 14, 15, 13, 14, 11]) >>> sr.vbt.plot_pattern([1, 2, 3, 2, 1]).show() ``` ![](/assets/images/api/sr_plot_pattern.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/sr_plot_pattern.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.figure import make_figure from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if isinstance(interp_mode, str): interp_mode = map_enum_fields(interp_mode, InterpMode) if isinstance(rescale_mode, str): rescale_mode = map_enum_fields(rescale_mode, RescaleMode) if isinstance(error_type, str): error_type = map_enum_fields(error_type, ErrorType) if max_error_interp_mode is not None and isinstance(max_error_interp_mode, str): max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode) if max_error_interp_mode is None: max_error_interp_mode = interp_mode obj_trace_kwargs = merge_dicts( dict(line=dict(color=plotting_cfg["color_schema"]["blue"])), obj_trace_kwargs, ) if pattern_trace_kwargs is None: pattern_trace_kwargs = {} if lower_max_error_trace_kwargs is None: lower_max_error_trace_kwargs = {} if upper_max_error_trace_kwargs is None: upper_max_error_trace_kwargs = {} if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) self_col = self.select_col(column=column) if plot_obj: # Plot object fig = self_col.lineplot( trace_kwargs=obj_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) # Reconstruct pattern and max error bands pattern_sr, max_error_sr = self_col.fit_pattern( pattern, interp_mode=interp_mode, rescale_mode=rescale_mode, vmin=vmin, vmax=vmax, pmin=pmin, pmax=pmax, invert=invert, error_type=error_type, max_error=max_error, max_error_interp_mode=max_error_interp_mode, ) # Plot pattern and max error bands def_pattern_trace_kwargs = dict( name=f"Pattern", connectgaps=True, ) if interp_mode == 
InterpMode.Discrete: _pattern_trace_kwargs = merge_dicts( def_pattern_trace_kwargs, dict( mode="lines+markers", marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75)), line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75), dash="dot"), ), pattern_trace_kwargs, ) else: if fill_distance: _pattern_trace_kwargs = merge_dicts( def_pattern_trace_kwargs, dict( mode="lines", line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75)), fill="tonexty", fillcolor=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.25), ), pattern_trace_kwargs, ) else: _pattern_trace_kwargs = merge_dicts( def_pattern_trace_kwargs, dict( mode="lines", line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75), dash="dot"), ), pattern_trace_kwargs, ) fig = pattern_sr.rename(None).vbt.plot( trace_kwargs=_pattern_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) # Plot max error bounds if not np.isnan(max_error).all(): def_max_error_trace_kwargs = dict( name="Max error", connectgaps=True, ) if max_error_interp_mode == InterpMode.Discrete: _lower_max_error_trace_kwargs = merge_dicts( def_max_error_trace_kwargs, dict( mode="markers+lines", marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5)), line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5), dash="dot"), ), lower_max_error_trace_kwargs, ) _upper_max_error_trace_kwargs = merge_dicts( def_max_error_trace_kwargs, dict( mode="markers+lines", marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5)), line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5), dash="dot"), showlegend=False, ), upper_max_error_trace_kwargs, ) else: _lower_max_error_trace_kwargs = merge_dicts( def_max_error_trace_kwargs, dict( mode="lines", line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5), dash="dot"), ), lower_max_error_trace_kwargs, ) _upper_max_error_trace_kwargs = merge_dicts( 
def_max_error_trace_kwargs, dict( mode="lines", line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5), dash="dot"), fillcolor=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.1), fill="tonexty", showlegend=False, ), upper_max_error_trace_kwargs, ) fig = ( (pattern_sr - max_error_sr) .rename(None) .vbt.plot( trace_kwargs=_lower_max_error_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) ) fig = ( (pattern_sr + max_error_sr) .rename(None) .vbt.plot( trace_kwargs=_upper_max_error_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) ) return fig @property def plots_defaults(self) -> tp.Kwargs: """Defaults for `GenericAccessor.plots`. Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and `plots` from `vectorbtpro._settings.generic`.""" from vectorbtpro._settings import settings generic_plots_cfg = settings["generic"]["plots"] return merge_dicts(Analyzable.plots_defaults.__get__(self), generic_plots_cfg) _subplots: tp.ClassVar[Config] = HybridConfig( dict( plot=dict( check_is_not_grouped=True, plot_func="plot", pass_trace_names=False, tags="generic", ) ), ) @property def subplots(self) -> Config: return self._subplots if settings["importing"]["sklearn"]: from sklearn.exceptions import NotFittedError from sklearn.preprocessing import ( Binarizer, MinMaxScaler, MaxAbsScaler, Normalizer, RobustScaler, StandardScaler, QuantileTransformer, PowerTransformer, ) from sklearn.utils.validation import check_is_fitted transform_config = ReadonlyConfig( { "binarize": dict(transformer=Binarizer, docstring="See `sklearn.preprocessing.Binarizer`."), "minmax_scale": dict(transformer=MinMaxScaler, docstring="See `sklearn.preprocessing.MinMaxScaler`."), "maxabs_scale": dict(transformer=MaxAbsScaler, docstring="See `sklearn.preprocessing.MaxAbsScaler`."), "normalize": dict(transformer=Normalizer, docstring="See `sklearn.preprocessing.Normalizer`."), "robust_scale": dict(transformer=RobustScaler, docstring="See 
`sklearn.preprocessing.RobustScaler`."), "scale": dict(transformer=StandardScaler, docstring="See `sklearn.preprocessing.StandardScaler`."), "quantile_transform": dict( transformer=QuantileTransformer, docstring="See `sklearn.preprocessing.QuantileTransformer`.", ), "power_transform": dict( transformer=PowerTransformer, docstring="See `sklearn.preprocessing.PowerTransformer`.", ), } ) """_""" __pdoc__[ "transform_config" ] = f"""Config of transform methods to be attached to `GenericAccessor`. ```python {transform_config.prettify()} ``` """ GenericAccessor = attach_transform_methods(transform_config)(GenericAccessor) GenericAccessor.override_metrics_doc(__pdoc__) GenericAccessor.override_subplots_doc(__pdoc__) class GenericSRAccessor(GenericAccessor, BaseSRAccessor): """Accessor on top of data of any type. For Series only. Accessible via `pd.Series.vbt`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, mapping: tp.Optional[tp.MappingLike] = None, _full_init: bool = True, **kwargs, ) -> None: BaseSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs) if _full_init: GenericAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs) def fit_pattern( self, pattern: tp.ArrayLike, interp_mode: tp.Union[int, str] = "mixed", rescale_mode: tp.Union[int, str] = "minmax", vmin: float = np.nan, vmax: float = np.nan, pmin: float = np.nan, pmax: float = np.nan, invert: bool = False, error_type: tp.Union[int, str] = "absolute", max_error: tp.ArrayLike = np.nan, max_error_interp_mode: tp.Union[None, int, str] = None, jitted: tp.JittedOption = None, ) -> tp.Tuple[tp.Series, tp.Series]: """See `vectorbtpro.generic.nb.patterns.fit_pattern_nb`.""" if isinstance(interp_mode, str): interp_mode = map_enum_fields(interp_mode, InterpMode) if isinstance(rescale_mode, str): rescale_mode = map_enum_fields(rescale_mode, RescaleMode) if isinstance(error_type, str): error_type = map_enum_fields(error_type, 
ErrorType) if max_error_interp_mode is not None and isinstance(max_error_interp_mode, str): max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode) if max_error_interp_mode is None: max_error_interp_mode = interp_mode pattern = reshaping.to_1d_array(pattern) max_error = reshaping.broadcast_array_to(max_error, len(pattern)) func = jit_reg.resolve_option(nb.fit_pattern_nb, jitted) fit_pattern, fit_max_error = func( self.to_1d_array(), pattern, interp_mode=interp_mode, rescale_mode=rescale_mode, vmin=vmin, vmax=vmax, pmin=pmin, pmax=pmax, invert=invert, error_type=error_type, max_error=max_error, max_error_interp_mode=max_error_interp_mode, ) pattern_sr = self.wrapper.wrap(fit_pattern) max_error_sr = self.wrapper.wrap(fit_max_error) return pattern_sr, max_error_sr def to_renko( self, brick_size: tp.ArrayLike, relative: tp.ArrayLike = False, start_value: tp.Optional[float] = None, max_out_len: tp.Optional[int] = None, reset_index: bool = False, return_uptrend: bool = False, jitted: tp.JittedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.Union[tp.Series, tp.Tuple[tp.Series, tp.Series]]: """See `vectorbtpro.generic.nb.base.to_renko_1d_nb`.""" func = jit_reg.resolve_option(nb.to_renko_1d_nb, jitted) arr_out, idx_out, uptrend_out = func( self.to_1d_array(), reshaping.broadcast_array_to(brick_size, self.wrapper.shape[0]), relative=reshaping.broadcast_array_to(relative, self.wrapper.shape[0]), start_value=start_value, max_out_len=max_out_len, ) if reset_index: new_index = pd.RangeIndex(stop=len(idx_out)) else: new_index = self.wrapper.index[idx_out] wrap_kwargs = merge_dicts( dict(index=new_index), wrap_kwargs, ) sr_out = self.wrapper.wrap(arr_out, group_by=False, **wrap_kwargs) if return_uptrend: uptrend_out = self.wrapper.wrap(uptrend_out, group_by=False, **wrap_kwargs) return sr_out, uptrend_out return sr_out def to_renko_ohlc( self, brick_size: tp.ArrayLike, relative: tp.ArrayLike = False, start_value: tp.Optional[float] = None, max_out_len: 
tp.Optional[int] = None, reset_index: bool = False, jitted: tp.JittedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.Frame: """See `vectorbtpro.generic.nb.base.to_renko_ohlc_1d_nb`.""" func = jit_reg.resolve_option(nb.to_renko_ohlc_1d_nb, jitted) arr_out, idx_out = func( self.to_1d_array(), reshaping.broadcast_array_to(brick_size, self.wrapper.shape[0]), relative=reshaping.broadcast_array_to(relative, self.wrapper.shape[0]), start_value=start_value, max_out_len=max_out_len, ) if reset_index: new_index = pd.RangeIndex(stop=len(idx_out)) else: new_index = self.wrapper.index[idx_out] wrap_kwargs = merge_dicts( dict(index=new_index, columns=["Open", "High", "Low", "Close"]), wrap_kwargs, ) return self.wrapper.wrap(arr_out, group_by=False, **wrap_kwargs) class GenericDFAccessor(GenericAccessor, BaseDFAccessor): """Accessor on top of data of any type. For DataFrames only. Accessible via `pd.DataFrame.vbt`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, mapping: tp.Optional[tp.MappingLike] = None, _full_init: bool = True, **kwargs, ) -> None: BaseDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs) if _full_init: GenericAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs) def band(self, band_name: str, return_meta: bool = False) -> tp.Union[tp.Series, dict]: """Calculate the band by its name. 
Examples for the band name: * "50%": 50th quantile * "Q=50%": 50th quantile * "Q=0.5": 50th quantile * "Z=1.96": Z-score of 1.96 * "P=95%": One-tailed significance level of 0.95 (translated into z-score) * "P=0.95": One-tailed significance level of 0.95 (translated into z-score) * "median": Median (50th quantile) * "mean": Mean across all columns * "min": Min across all columns * "max": Max across all columns * "lowest": Column with the lowest final value * "highest": Column with the highest final value """ band_name = band_name.lower().replace(" ", "") if band_name == "median": band_name = "50%" if "%" in band_name and not band_name.startswith("q=") and not band_name.startswith("p="): band_name = f"q={band_name}" if band_name.startswith("q="): if "%" in band_name: q = float(band_name.replace("q=", "").replace("%", "")) / 100 else: q = float(band_name.replace("q=", "")) q_readable = np.around(q * 100, decimals=2) if q_readable.is_integer(): q_readable = int(q_readable) band_title = f"Q={q_readable}% (proj)" def band_func(df, _q=q): return df.quantile(_q, axis=1) elif band_name.startswith("z="): z = float(band_name.replace("z=", "")) z_readable = np.around(z, decimals=2) if z_readable.is_integer(): z_readable = int(z_readable) band_title = f"Z={z_readable} (proj)" def band_func(df, _z=z): return df.mean(axis=1) + _z * df.std(axis=1) elif band_name.startswith("p="): import scipy.stats as st if "%" in band_name: p = float(band_name.replace("p=", "").replace("%", "")) / 100 else: p = float(band_name.replace("p=", "")) p_readable = np.around(p * 100, decimals=2) if p_readable.is_integer(): p_readable = int(p_readable) band_title = f"P={p_readable}% (proj)" z = st.norm.ppf(p) def band_func(df, _z=z): return df.mean(axis=1) + _z * df.std(axis=1) elif band_name == "mean": band_title = "Mean (proj)" def band_func(df): return df.mean(axis=1) elif band_name == "min": band_title = "Min (proj)" def band_func(df): return df.min(axis=1) elif band_name == "max": band_title = "Max 
(proj)" def band_func(df): return df.max(axis=1) elif band_name == "lowest": band_title = "Lowest (proj)" def band_func(df): return df[df.ffill().iloc[-1].idxmin()] elif band_name == "highest": band_title = "Highest (proj)" def band_func(df): return df[df.ffill().iloc[-1].idxmax()] else: raise ValueError(f"Invalid band_name: '{band_name}'") if return_meta: return dict(band_name=band_name, band_title=band_title, band_func=band_func) return band_func(self.obj) def plot_projections( self, plot_projections: bool = True, plot_bands: bool = True, plot_lower: tp.Union[bool, str, tp.Callable] = True, plot_middle: tp.Union[bool, str, tp.Callable] = True, plot_upper: tp.Union[bool, str, tp.Callable] = True, plot_aux_middle: tp.Union[bool, str, tp.Callable] = True, plot_fill: bool = True, colorize: tp.Union[bool, str, tp.Callable] = True, rename_levels: tp.Union[None, dict, tp.Sequence] = None, projection_trace_kwargs: tp.KwargsLike = None, upper_trace_kwargs: tp.KwargsLike = None, middle_trace_kwargs: tp.KwargsLike = None, lower_trace_kwargs: tp.KwargsLike = None, aux_middle_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot a DataFrame where each column is a projection. If `plot_projections` is True, will plot each projection as a semi-transparent line. The arguments `plot_lower`, `plot_middle`, `plot_aux_middle`, and `plot_upper` represent bands and accept the following: * True: Plot the band using the default quantile (20/50/80) * False: Do not plot the band * callable: Custom function that accepts DataFrame and reduces it across columns * For other options see `GenericDFAccessor.band` !!! note When providing z-scores, the upper should be positive, the middle should be "mean", and the lower should be negative. When providing significance levels, the middle should be "mean", while the lower should be positive and lower than the upper, for example, 25% and 75%. 
Argument `colorize` allows the following values: * False: Do not colorize * True or "median": Colorize by median * "mean": Colorize by mean * "last": Colorize by last value * callable: Custom function that accepts (rebased to 0) Series/DataFrame with nans already dropped and reduces it across rows Colorization is performed by mapping the metric value of the band to the range between the minimum and maximum value across all projections where 0 is always the middle point. If none of the bands is plotted, projections got colorized. Otherwise, projections stay gray. Usage: ```pycon >>> df = pd.DataFrame({ ... 0: [10, 11, 12, 11, 10], ... 1: [10, 12, 14, np.nan, np.nan], ... 2: [10, 12, 11, 12, np.nan], ... 3: [10, 9, 8, 9, 8], ... 4: [10, 11, np.nan, np.nan, np.nan], ... }) >>> df.vbt.plot_projections().show() ``` ![](/assets/images/api/df_plot_projections.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/df_plot_projections.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.figure import make_figure from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if projection_trace_kwargs is None: projection_trace_kwargs = {} if lower_trace_kwargs is None: lower_trace_kwargs = {} if upper_trace_kwargs is None: upper_trace_kwargs = {} if middle_trace_kwargs is None: middle_trace_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} # Resolve band functions and names if len(self.obj.columns) == 1: plot_bands = False if not plot_bands: plot_lower = False plot_middle = False plot_upper = False plot_aux_middle = False if isinstance(plot_lower, bool): if plot_lower: plot_lower = "20%" else: plot_lower = None if isinstance(plot_middle, bool): if plot_middle: plot_middle = "50%" else: plot_middle = None if isinstance(plot_upper, bool): if plot_upper: plot_upper = "80%" else: plot_upper = None if isinstance(plot_aux_middle, bool): if plot_aux_middle: plot_aux_middle = "mean" else: plot_aux_middle = None def 
_resolve_band_and_name(band_func, arg_name): band_title = None if isinstance(band_func, str): band_func_meta = self.band(band_func, return_meta=True) band_title = band_func_meta["band_title"] band_func = band_func_meta["band_func"] if band_func is not None and not callable(band_func): raise TypeError(f"Argument {arg_name} has wrong type '{type(band_func)}'") return band_func, band_title plot_lower, lower_name = _resolve_band_and_name(plot_lower, "plot_lower") if lower_name is None: lower_name = "Lower (proj)" plot_middle, middle_name = _resolve_band_and_name(plot_middle, "plot_middle") if middle_name is None: middle_name = "Middle (proj)" plot_upper, upper_name = _resolve_band_and_name(plot_upper, "plot_upper") if upper_name is None: upper_name = "Upper (proj)" plot_aux_middle, aux_middle_name = _resolve_band_and_name(plot_aux_middle, "plot_aux_middle") if aux_middle_name is None: aux_middle_name = "Aux middle (proj)" if isinstance(colorize, bool): if colorize: colorize = "median" else: colorize = None if colorize is not None: if isinstance(colorize, str): colorize = colorize.lower().replace(" ", "") if colorize == "median": colorize = lambda x: x.median() elif colorize == "mean": colorize = lambda x: x.mean() elif colorize == "last": colorize = lambda x: x.ffill().iloc[-1] else: raise ValueError(f"Argument colorize has wrong value '{colorize}'") if colorize is not None and not callable(colorize): raise TypeError(f"Argument colorize has wrong type '{type(colorize)}'") if colorize is not None: proj_min = colorize(self.obj - self.obj.iloc[0]).min() proj_max = colorize(self.obj - self.obj.iloc[0]).max() else: proj_min = None proj_max = None if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) if len(self.obj.columns) > 0: if plot_projections: # Plot projections for col in range(self.wrapper.shape[1]): proj_sr = self.obj.iloc[:, col].dropna() hovertemplate = f"(%{{x}}, %{{y}})" if not checks.is_default_index(self.wrapper.columns): level_names = [] 
level_values = [] if isinstance(self.wrapper.columns, pd.MultiIndex): for l in range(self.wrapper.columns.nlevels): level_names.append(self.wrapper.columns.names[l]) level_values.append(self.wrapper.columns.get_level_values(l)[col]) else: level_names.append(self.wrapper.columns.name) level_values.append(self.wrapper.columns[col]) for l in range(len(level_names)): level_name = level_names[l] level_value = level_values[l] if rename_levels is not None: if isinstance(rename_levels, dict): if level_name in rename_levels: level_name = rename_levels[level_name] elif l in rename_levels: level_name = rename_levels[l] else: level_name = rename_levels[l] if level_name is None: level_name = f"Level {l}" hovertemplate += f"
{level_name}: {level_value}" if colorize is not None: proj_color = map_value_to_cmap( colorize(proj_sr - proj_sr.iloc[0]), [ plotting_cfg["color_schema"]["red"], plotting_cfg["color_schema"]["yellow"], plotting_cfg["color_schema"]["green"], ], vmin=proj_min, vcenter=0, vmax=proj_max, ) else: proj_color = plotting_cfg["color_schema"]["gray"] if not plot_bands: proj_opacity = 0.5 else: proj_opacity = 0.1 _projection_trace_kwargs = merge_dicts( dict( name=f"proj ({self.obj.shape[1]})", line=dict(color=proj_color), opacity=proj_opacity, legendgroup="proj", showlegend=col == 0, hovertemplate=hovertemplate, ), projection_trace_kwargs, ) proj_sr.rename(None).vbt.lineplot( trace_kwargs=_projection_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if plot_bands and len(self.obj.columns) > 1: # Calculate bands if plot_lower is not None: lower_band = plot_lower(self.obj).dropna() else: lower_band = None if plot_middle is not None: middle_band = plot_middle(self.obj).dropna() else: middle_band = None if plot_upper is not None: upper_band = plot_upper(self.obj).dropna() else: upper_band = None if plot_aux_middle is not None: aux_middle_band = plot_aux_middle(self.obj).dropna() else: aux_middle_band = None if lower_band is not None: # Plot lower band def_lower_trace_kwargs = dict(name=lower_name) if colorize is not None: lower_color = map_value_to_cmap( colorize(lower_band - lower_band.iloc[0]), [ plotting_cfg["color_schema"]["red"], plotting_cfg["color_schema"]["yellow"], plotting_cfg["color_schema"]["green"], ], vmin=proj_min, vcenter=0, vmax=proj_max, ) def_lower_trace_kwargs["line"] = dict(color=adjust_opacity(lower_color, 0.75)) else: lower_color = plotting_cfg["color_schema"]["gray"] def_lower_trace_kwargs["line"] = dict(color=adjust_opacity(lower_color, 0.5)) lower_band.rename(None).vbt.lineplot( trace_kwargs=merge_dicts(def_lower_trace_kwargs, lower_trace_kwargs), add_trace_kwargs=add_trace_kwargs, fig=fig, ) if middle_band is not None: # Plot middle band 
def_middle_trace_kwargs = dict(name=middle_name) if colorize is not None: middle_color = map_value_to_cmap( colorize(middle_band - middle_band.iloc[0]), [ plotting_cfg["color_schema"]["red"], plotting_cfg["color_schema"]["yellow"], plotting_cfg["color_schema"]["green"], ], vmin=proj_min, vcenter=0, vmax=proj_max, ) else: middle_color = plotting_cfg["color_schema"]["gray"] def_middle_trace_kwargs["line"] = dict(color=middle_color) if plot_fill and lower_band is not None: def_middle_trace_kwargs["fill"] = "tonexty" def_middle_trace_kwargs["fillcolor"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.25) middle_band.rename(None).vbt.lineplot( trace_kwargs=merge_dicts(def_middle_trace_kwargs, middle_trace_kwargs), add_trace_kwargs=add_trace_kwargs, fig=fig, ) if upper_band is not None: # Plot upper band def_upper_trace_kwargs = dict(name=upper_name) if colorize is not None: upper_color = map_value_to_cmap( colorize(upper_band - upper_band.iloc[0]), [ plotting_cfg["color_schema"]["red"], plotting_cfg["color_schema"]["yellow"], plotting_cfg["color_schema"]["green"], ], vmin=proj_min, vcenter=0, vmax=proj_max, ) def_upper_trace_kwargs["line"] = dict(color=adjust_opacity(upper_color, 0.75)) else: upper_color = plotting_cfg["color_schema"]["gray"] def_upper_trace_kwargs["line"] = dict(color=adjust_opacity(upper_color, 0.5)) if plot_fill and (lower_band is not None or middle_band is not None): def_upper_trace_kwargs["fill"] = "tonexty" def_upper_trace_kwargs["fillcolor"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.25) upper_band.rename(None).vbt.lineplot( trace_kwargs=merge_dicts(def_upper_trace_kwargs, upper_trace_kwargs), add_trace_kwargs=add_trace_kwargs, fig=fig, ) if aux_middle_band is not None: # Plot auxiliary band def_aux_middle_trace_kwargs = dict(name=aux_middle_name) if colorize is not None: aux_middle_color = map_value_to_cmap( colorize(aux_middle_band - aux_middle_band.iloc[0]), [ plotting_cfg["color_schema"]["red"], 
plotting_cfg["color_schema"]["yellow"], plotting_cfg["color_schema"]["green"], ], vmin=proj_min, vcenter=0, vmax=proj_max, ) else: aux_middle_color = plotting_cfg["color_schema"]["gray"] def_aux_middle_trace_kwargs["line"] = dict(dash="dot", color=aux_middle_color) aux_middle_band.rename(None).vbt.lineplot( trace_kwargs=merge_dicts(def_aux_middle_trace_kwargs, aux_middle_trace_kwargs), add_trace_kwargs=add_trace_kwargs, fig=fig, ) return fig
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Class for analyzing data."""

from vectorbtpro import _typing as tp
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.generic.plots_builder import PlotsBuilderMixin
from vectorbtpro.generic.stats_builder import StatsBuilderMixin

__all__ = [
    "Analyzable",
]


class MetaAnalyzable(type(Wrapping), type(StatsBuilderMixin), type(PlotsBuilderMixin)):
    """Metaclass for `Analyzable`.

    Derives from the metaclasses of `Wrapping`, `StatsBuilderMixin`, and `PlotsBuilderMixin`
    so that a single class can inherit from all three bases."""


AnalyzableT = tp.TypeVar("AnalyzableT", bound="Analyzable")


class Analyzable(Wrapping, StatsBuilderMixin, PlotsBuilderMixin, metaclass=MetaAnalyzable):
    """Wrapping class that additionally exposes the stats and plots builder mixins.

    Args:
        wrapper (ArrayWrapper): Array wrapper passed to `Wrapping.__init__`.
        **kwargs: Keyword arguments passed to `Wrapping.__init__`."""

    def __init__(self, wrapper: ArrayWrapper, **kwargs) -> None:
        # Only `Wrapping` receives the wrapper and the keyword arguments
        Wrapping.__init__(self, wrapper, **kwargs)
        # Both builder mixins are initialized explicitly and without arguments,
        # in the same order as they appear in the bases
        for builder_mixin in (StatsBuilderMixin, PlotsBuilderMixin):
            builder_mixin.__init__(self)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Class decorators for generic accessors."""

import inspect

from vectorbtpro import _typing as tp
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts, Config
from vectorbtpro.utils.parsing import get_func_arg_names

__all__ = []


def attach_nb_methods(config: Config) -> tp.ClassWrapper:
    """Return a class decorator that attaches one method per entry of `config`.

    Each key of `config` is the name of the method to create, and each value is a dictionary
    that is read using the following keys:

    * `func`: Function to wrap. It is always called with `self.to_2d_array()` as the first argument.
    * `is_reducing`: Whether to wrap the output with `wrap_reduced` rather than `wrap`. Defaults to False.
    * `disable_jitted`: Whether to skip resolving the `jitted` option. If True, passing a `jitted`
        other than None raises an error. Defaults to False.
    * `disable_chunked`: Whether to skip resolving the `chunked` option. If True, passing a `chunked`
        other than None raises an error. Defaults to False.
    * `replace_signature`: Whether to expose the signature of `func` (with its first parameter
        replaced by `self`, and with the enabled options appended) on the new method. Defaults to True.
    * `wrap_kwargs`: Default keyword arguments for wrapping. Merged with the dict supplied by the user.
        Defaults to `dict(name_or_index=target_name)` for reducing functions and to None otherwise.

    The decorated class must be a subclass of `vectorbtpro.base.wrapping.Wrapping`.
    """

    def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
        from vectorbtpro.base.wrapping import Wrapping

        checks.assert_subclass_of(cls, Wrapping)

        for target_name, method_cfg in config.items():
            func = method_cfg["func"]
            reducing = method_cfg.get("is_reducing", False)
            no_jitted = method_cfg.get("disable_jitted", False)
            no_chunked = method_cfg.get("disable_chunked", False)
            swap_signature = method_cfg.get("replace_signature", True)
            fallback_wrap_kwargs = dict(name_or_index=target_name) if reducing else None
            base_wrap_kwargs = method_cfg.get("wrap_kwargs", fallback_wrap_kwargs)

            # Per-entry values are bound as keyword defaults so that each generated method
            # keeps its own configuration instead of the values of the last loop iteration
            def new_method(
                self,
                *args,
                _target_name: str = target_name,
                _func: tp.Callable = func,
                _is_reducing: bool = reducing,
                _disable_jitted: bool = no_jitted,
                _disable_chunked: bool = no_chunked,
                _default_wrap_kwargs: tp.KwargsLike = base_wrap_kwargs,
                jitted: tp.JittedOption = None,
                chunked: tp.ChunkedOption = None,
                wrap_kwargs: tp.KwargsLike = None,
                **kwargs,
            ) -> tp.SeriesFrame:
                call_args = (self.to_2d_array(),) + args
                # Check the arguments against the source signature before resolving any option
                inspect.signature(_func).bind(*call_args, **kwargs)

                if _disable_jitted:
                    if jitted is not None:
                        raise ValueError("This method doesn't support jitting")
                else:
                    _func = jit_reg.resolve_option(_func, jitted)

                if _disable_chunked:
                    if chunked is not None:
                        raise ValueError("This method doesn't support chunking")
                else:
                    _func = ch_reg.resolve_option(_func, chunked)

                out = _func(*call_args, **kwargs)
                wrap_kwargs = merge_dicts(_default_wrap_kwargs, wrap_kwargs)
                wrap_func = self.wrapper.wrap_reduced if _is_reducing else self.wrapper.wrap
                return wrap_func(out, **wrap_kwargs)

            if swap_signature:
                # Expose the source function's parameters: `self` takes the place of the array
                # argument, followed by whichever of the extra options are enabled
                source_sig = inspect.signature(func)
                own_params = inspect.signature(new_method).parameters
                exposed_params = [next(iter(own_params.values()))]
                exposed_params.extend(tuple(source_sig.parameters.values())[1:])
                if not no_jitted:
                    exposed_params.append(own_params["jitted"])
                if not no_chunked:
                    exposed_params.append(own_params["chunked"])
                exposed_params.append(own_params["wrap_kwargs"])
                new_method.__signature__ = source_sig.replace(parameters=tuple(exposed_params))

            new_method.__name__ = target_name
            new_method.__module__ = cls.__module__
            new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
            new_method.__doc__ = f"See `{func.__module__ + '.' + func.__name__}`."
            setattr(cls, target_name, new_method)
        return cls

    return wrapper


def attach_transform_methods(config: Config) -> tp.ClassWrapper:
    """Class decorator to add transformation methods.

    `config` must contain target method names (keys) and dictionaries (values) with the following keys:

    * `transformer`: Transformer class/object.
    * `docstring`: Method docstring.
    * `replace_signature`: Whether to replace the target signature. Defaults to True.

    The class must be a subclass of `vectorbtpro.generic.accessors.GenericAccessor`.
    """

    def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
        from vectorbtpro.generic.accessors import TransformerT

        checks.assert_subclass_of(cls, "GenericAccessor")

        for target_name, settings in config.items():
            transformer = settings["transformer"]
            docstring = settings.get("docstring", f"See `{transformer.__name__}`.")
            replace_signature = settings.get("replace_signature", True)

            def new_method(
                self,
                _target_name: str = target_name,
                _transformer: tp.Union[tp.Type[TransformerT], TransformerT] = transformer,
                **kwargs,
            ) -> tp.SeriesFrame:
                if inspect.isclass(_transformer):
                    arg_names = get_func_arg_names(_transformer.__init__)
                    transformer_kwargs = dict()
                    for arg_name in arg_names:
                        if arg_name in kwargs:
                            transformer_kwargs[arg_name] = kwargs.pop(arg_name)
                    return self.transform(_transformer(**transformer_kwargs), **kwargs)
                return self.transform(_transformer, **kwargs)

            if replace_signature:
                source_sig = inspect.signature(transformer.__init__)
                new_method_params = tuple(inspect.signature(new_method).parameters.values())
                if inspect.isclass(transformer):
                    transformer_params = tuple(source_sig.parameters.values())
source_sig = inspect.Signature( (new_method_params[0],) + transformer_params[1:] + (new_method_params[-1],), ) new_method.__signature__ = source_sig else: source_sig = inspect.Signature((new_method_params[0],) + (new_method_params[-1],)) new_method.__signature__ = source_sig new_method.__name__ = target_name new_method.__module__ = cls.__module__ new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}" new_method.__doc__ = docstring setattr(cls, target_name, new_method) return cls return wrapper # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Base class for working with drawdown records. Drawdown records capture information on drawdowns. Since drawdowns are ranges, they subclass `vectorbtpro.generic.ranges.Ranges`. !!! warning `Drawdowns` return both recovered AND active drawdowns, which may skew your performance results. To only consider recovered drawdowns, you should explicitly query `status_recovered` attribute. Using `Drawdowns.from_price`, you can generate drawdown records for any time series and analyze them right away. ```pycon >>> from vectorbtpro import * >>> price = vbt.YFData.pull( ... "BTC-USD", ... start="2019-10 UTC", ... end="2020-01 UTC" ... 
).get('Close') ``` [=100% "100%"]{: .candystripe .candystripe-animate } ```pycon >>> price = price.rename(None) >>> drawdowns = vbt.Drawdowns.from_price(price, wrapper_kwargs=dict(freq='d')) >>> drawdowns.readable Drawdown Id Column Start Index Valley Index \\ 0 0 0 2019-10-02 00:00:00+00:00 2019-10-06 00:00:00+00:00 1 1 0 2019-10-09 00:00:00+00:00 2019-10-24 00:00:00+00:00 2 2 0 2019-10-27 00:00:00+00:00 2019-12-17 00:00:00+00:00 End Index Peak Value Valley Value End Value Status 0 2019-10-09 00:00:00+00:00 8393.041992 7988.155762 8595.740234 Recovered 1 2019-10-25 00:00:00+00:00 8595.740234 7493.488770 8660.700195 Recovered 2 2019-12-31 00:00:00+00:00 9551.714844 6640.515137 7193.599121 Active >>> drawdowns.duration.max(wrap_kwargs=dict(to_timedelta=True)) Timedelta('66 days 00:00:00') ``` ## From accessors Moreover, all generic accessors have a property `drawdowns` and a method `get_drawdowns`: ```pycon >>> # vectorbtpro.generic.accessors.GenericAccessor.drawdowns.coverage >>> price.vbt.drawdowns.coverage 0.967391304347826 ``` ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Drawdowns.metrics`. ```pycon >>> df = pd.DataFrame({ ... 'a': [1, 2, 1, 3, 2], ... 'b': [2, 3, 1, 2, 1] ... 
}) >>> drawdowns = df.vbt(freq='d').drawdowns >>> drawdowns['a'].stats() Start 0 End 4 Period 5 days 00:00:00 Coverage [%] 80.0 Total Records 2 Total Recovered Drawdowns 1 Total Active Drawdowns 1 Active Drawdown [%] 33.333333 Active Duration 2 days 00:00:00 Active Recovery [%] 0.0 Active Recovery Return [%] 0.0 Active Recovery Duration 0 days 00:00:00 Max Drawdown [%] 50.0 Avg Drawdown [%] 50.0 Max Drawdown Duration 2 days 00:00:00 Avg Drawdown Duration 2 days 00:00:00 Max Recovery Return [%] 200.0 Avg Recovery Return [%] 200.0 Max Recovery Duration 1 days 00:00:00 Avg Recovery Duration 1 days 00:00:00 Avg Recovery Duration Ratio 1.0 Name: a, dtype: object ``` By default, the metrics `max_dd`, `avg_dd`, `max_dd_duration`, and `avg_dd_duration` do not include active drawdowns. To change that, pass `incl_active=True`: ```pycon >>> drawdowns['a'].stats(settings=dict(incl_active=True)) Start 0 End 4 Period 5 days 00:00:00 Coverage [%] 80.0 Total Records 2 Total Recovered Drawdowns 1 Total Active Drawdowns 1 Active Drawdown [%] 33.333333 Active Duration 2 days 00:00:00 Active Recovery [%] 0.0 Active Recovery Return [%] 0.0 Active Recovery Duration 0 days 00:00:00 Max Drawdown [%] 50.0 Avg Drawdown [%] 41.666667 Max Drawdown Duration 2 days 00:00:00 Avg Drawdown Duration 2 days 00:00:00 Max Recovery Return [%] 200.0 Avg Recovery Return [%] 200.0 Max Recovery Duration 1 days 00:00:00 Avg Recovery Duration 1 days 00:00:00 Avg Recovery Duration Ratio 1.0 Name: a, dtype: object ``` `Drawdowns.stats` also supports (re-)grouping: ```pycon >>> drawdowns['a'].stats(group_by=True) UserWarning: Metric 'active_dd' does not support grouped data UserWarning: Metric 'active_duration' does not support grouped data UserWarning: Metric 'active_recovery' does not support grouped data UserWarning: Metric 'active_recovery_return' does not support grouped data UserWarning: Metric 'active_recovery_duration' does not support grouped data Start 0 End 4 Period 5 days 00:00:00 Coverage [%] 80.0 
Total Records                                2
Total Recovered Drawdowns                    1
Total Active Drawdowns                       1
Max Drawdown [%]                          50.0
Avg Drawdown [%]                          50.0
Max Drawdown Duration          2 days 00:00:00
Avg Drawdown Duration          2 days 00:00:00
Max Recovery Return [%]                  200.0
Avg Recovery Return [%]                  200.0
Max Recovery Duration          1 days 00:00:00
Avg Recovery Duration          1 days 00:00:00
Avg Recovery Duration Ratio                1.0
Name: group, dtype: object
```

## Plots

!!! hint
    See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Drawdowns.subplots`.

`Drawdowns` class has a single subplot based on `Drawdowns.plot`:

```pycon
>>> drawdowns['a'].plots().show()
```

![](/assets/images/api/drawdowns_plots.light.svg#only-light){: .iimg loading=lazy }
![](/assets/images/api/drawdowns_plots.dark.svg#only-dark){: .iimg loading=lazy }
"""

import numpy as np
import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb
from vectorbtpro.generic.enums import DrawdownStatus, drawdown_dt, range_dt
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.records.decorators import override_field_config, attach_fields, attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.template import RepEval, RepFunc

__all__ = [
    "Drawdowns",
]

__pdoc__ = {}

# Field config applied to `Drawdowns` via `override_field_config`: sets the record dtype to
# `drawdown_dt`, gives titles to the drawdown-specific fields, and maps `valley_idx` to the
# index and `status` to the `DrawdownStatus` enum.
dd_field_config = ReadonlyConfig(
    dict(
        dtype=drawdown_dt,
        settings=dict(
            id=dict(title="Drawdown Id"),
            valley_idx=dict(title="Valley Index", mapping="index"),
            start_val=dict(
                title="Start Value",
            ),
            valley_val=dict(
                title="Valley Value",
            ),
            end_val=dict(
                title="End Value",
            ),
            status=dict(mapping=DrawdownStatus),
        ),
    )
)
"""_"""

__pdoc__[
    "dd_field_config"
] = f"""Field config for `Drawdowns`.

```python
{dd_field_config.prettify()}
```
"""

# Fields attached to `Drawdowns` via `attach_fields`; `status` is attached with filters enabled
# (other blocks of this class use the resulting `status_active` / `status_recovered` attributes).
dd_attach_field_config = ReadonlyConfig(dict(status=dict(attach_filters=True)))
"""_"""

__pdoc__[
    "dd_attach_field_config"
] = f"""Config of fields to be attached to `Drawdowns`.

```python
{dd_attach_field_config.prettify()}
```
"""

# Shortcut properties attached to `Drawdowns` via `attach_shortcut_properties`. Keys name the
# properties; `obj_type` is given for mapped arrays ("mapped_array") and reduced arrays ("red_array").
# NOTE(review): each key presumably resolves to the matching `get_<key>` method - confirm against
# `attach_shortcut_properties`.
dd_shortcut_config = ReadonlyConfig(
    dict(
        ranges=dict(),
        decline_ranges=dict(),
        recovery_ranges=dict(),
        drawdown=dict(obj_type="mapped_array"),
        avg_drawdown=dict(obj_type="red_array"),
        max_drawdown=dict(obj_type="red_array"),
        recovery_return=dict(obj_type="mapped_array"),
        avg_recovery_return=dict(obj_type="red_array"),
        max_recovery_return=dict(obj_type="red_array"),
        decline_duration=dict(obj_type="mapped_array"),
        recovery_duration=dict(obj_type="mapped_array"),
        recovery_duration_ratio=dict(obj_type="mapped_array"),
        active_drawdown=dict(obj_type="red_array"),
        active_duration=dict(obj_type="red_array"),
        active_recovery=dict(obj_type="red_array"),
        active_recovery_return=dict(obj_type="red_array"),
        active_recovery_duration=dict(obj_type="red_array"),
    )
)
"""_"""

__pdoc__[
    "dd_shortcut_config"
] = f"""Config of shortcut properties to be attached to `Drawdowns`.

```python
{dd_shortcut_config.prettify()}
```
"""

DrawdownsT = tp.TypeVar("DrawdownsT", bound="Drawdowns")


@attach_shortcut_properties(dd_shortcut_config)
@attach_fields(dd_attach_field_config)
@override_field_config(dd_field_config)
class Drawdowns(Ranges):
    """Extends `vectorbtpro.generic.ranges.Ranges` for working with drawdown records.

Requires `records_arr` to have all fields defined in `vectorbtpro.generic.enums.drawdown_dt`.""" @property def field_config(self) -> Config: return self._field_config @classmethod def from_price( cls: tp.Type[DrawdownsT], close: tp.ArrayLike, *, open: tp.Optional[tp.ArrayLike] = None, high: tp.Optional[tp.ArrayLike] = None, low: tp.Optional[tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, attach_data: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> DrawdownsT: """Build `Drawdowns` from price. `**kwargs` will be passed to `Drawdowns.__init__`.""" if wrapper_kwargs is None: wrapper_kwargs = {} close_arr = to_2d_array(close) open_arr = to_2d_array(open) if open is not None else None high_arr = to_2d_array(high) if high is not None else None low_arr = to_2d_array(low) if low is not None else None func = jit_reg.resolve_option(nb.get_drawdowns_nb, jitted) func = ch_reg.resolve_option(func, chunked) records_arr = func( open=open_arr, high=high_arr, low=low_arr, close=close_arr, sim_start=sim_start, sim_end=sim_end, ) if wrapper is None: wrapper = ArrayWrapper.from_obj(close, **resolve_dict(wrapper_kwargs)) elif len(wrapper_kwargs) > 0: wrapper = wrapper.replace(**wrapper_kwargs) return cls( wrapper, records_arr, open=open if attach_data else None, high=high if attach_data else None, low=low if attach_data else None, close=close if attach_data else None, **kwargs, ) def get_ranges(self, **kwargs) -> Ranges: """Get records of type `vectorbtpro.generic.ranges.Ranges` for peak-to-end ranges.""" new_records_arr = np.empty(self.values.shape, dtype=range_dt) new_records_arr["id"][:] = self.get_field_arr("id").copy() new_records_arr["col"][:] = self.get_field_arr("col").copy() new_records_arr["start_idx"][:] = self.get_field_arr("start_idx").copy() new_records_arr["end_idx"][:] = 
self.get_field_arr("end_idx").copy() new_records_arr["status"][:] = self.get_field_arr("status").copy() return Ranges.from_records( self.wrapper, new_records_arr, open=self._open, high=self._high, low=self._low, close=self._close, **kwargs, ) def get_decline_ranges(self, **kwargs) -> Ranges: """Get records of type `vectorbtpro.generic.ranges.Ranges` for peak-to-valley ranges.""" new_records_arr = np.empty(self.values.shape, dtype=range_dt) new_records_arr["id"][:] = self.get_field_arr("id").copy() new_records_arr["col"][:] = self.get_field_arr("col").copy() new_records_arr["start_idx"][:] = self.get_field_arr("start_idx").copy() new_records_arr["end_idx"][:] = self.get_field_arr("valley_idx").copy() new_records_arr["status"][:] = self.get_field_arr("status").copy() return Ranges.from_records( self.wrapper, new_records_arr, open=self._open, high=self._high, low=self._low, close=self._close, **kwargs, ) def get_recovery_ranges(self, **kwargs) -> Ranges: """Get records of type `vectorbtpro.generic.ranges.Ranges` for valley-to-end ranges.""" new_records_arr = np.empty(self.values.shape, dtype=range_dt) new_records_arr["id"][:] = self.get_field_arr("id").copy() new_records_arr["col"][:] = self.get_field_arr("col").copy() new_records_arr["start_idx"][:] = self.get_field_arr("valley_idx").copy() new_records_arr["end_idx"][:] = self.get_field_arr("end_idx").copy() new_records_arr["status"][:] = self.get_field_arr("status").copy() return Ranges.from_records( self.wrapper, new_records_arr, open=self._open, high=self._high, low=self._low, close=self._close, **kwargs, ) # ############# Drawdown ############# # def get_drawdown(self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs) -> MappedArray: """See `vectorbtpro.generic.nb.records.dd_drawdown_nb`. 
Takes into account both recovered and active drawdowns."""
        func = jit_reg.resolve_option(nb.dd_drawdown_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        drawdown = func(self.get_field_arr("start_val"), self.get_field_arr("valley_val"))
        return self.map_array(drawdown, **kwargs)

    def get_avg_drawdown(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get average drawdown (ADD).

        Based on `Drawdowns.drawdown`.

        `wrap_kwargs` is merged over `dict(name_or_index="avg_drawdown")`.
        `**kwargs` will be passed to the `mean` method of the mapped array."""
        wrap_kwargs = merge_dicts(dict(name_or_index="avg_drawdown"), wrap_kwargs)
        return self.drawdown.mean(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)

    def get_max_drawdown(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get maximum drawdown (MDD).

        Based on `Drawdowns.drawdown`.

        `wrap_kwargs` is merged over `dict(name_or_index="max_drawdown")`.
        `**kwargs` will be passed to the `min` method of the mapped array."""
        wrap_kwargs = merge_dicts(dict(name_or_index="max_drawdown"), wrap_kwargs)
        # The maximum drawdown is taken as the minimum of the drawdown values
        # NOTE(review): presumably because drawdown values are negative - confirm against `dd_drawdown_nb`
        return self.drawdown.min(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)

    # ############# Recovery ############# #

    def get_recovery_return(
        self,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> MappedArray:
        """See `vectorbtpro.generic.nb.records.dd_recovery_return_nb`.

        Computed from the fields `valley_val` and `end_val`.

        Takes into account both recovered and active drawdowns."""
        func = jit_reg.resolve_option(nb.dd_recovery_return_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        recovery_return = func(self.get_field_arr("valley_val"), self.get_field_arr("end_val"))
        return self.map_array(recovery_return, **kwargs)

    def get_avg_recovery_return(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get average recovery return.

        Based on `Drawdowns.recovery_return`.

        `wrap_kwargs` is merged over `dict(name_or_index="avg_recovery_return")`."""
        wrap_kwargs = merge_dicts(dict(name_or_index="avg_recovery_return"), wrap_kwargs)
        return self.recovery_return.mean(
            group_by=group_by,
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )

    def get_max_recovery_return(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get maximum recovery return.

        Based on `Drawdowns.recovery_return`.

        `wrap_kwargs` is merged over `dict(name_or_index="max_recovery_return")`."""
        wrap_kwargs = merge_dicts(dict(name_or_index="max_recovery_return"), wrap_kwargs)
        return self.recovery_return.max(
            group_by=group_by,
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )

    # ############# Duration ############# #

    def get_decline_duration(
        self,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> MappedArray:
        """See `vectorbtpro.generic.nb.records.dd_decline_duration_nb`.

        Computed from the fields `start_idx` and `valley_idx`.

        Takes into account both recovered and active drawdowns."""
        func = jit_reg.resolve_option(nb.dd_decline_duration_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        decline_duration = func(self.get_field_arr("start_idx"), self.get_field_arr("valley_idx"))
        return self.map_array(decline_duration, **kwargs)

    def get_recovery_duration(
        self,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> MappedArray:
        """See `vectorbtpro.generic.nb.records.dd_recovery_duration_nb`.

        Computed from the fields `valley_idx` and `end_idx`.

        Takes into account both recovered and active drawdowns."""
        # NOTE(review): the sentence "A value higher than 1 means the recovery was slower than
        # the decline." used to live in this docstring; it reads as a description of a ratio and
        # was moved to `get_recovery_duration_ratio` - confirm against the Numba functions.
        func = jit_reg.resolve_option(nb.dd_recovery_duration_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        recovery_duration = func(self.get_field_arr("valley_idx"), self.get_field_arr("end_idx"))
        return self.map_array(recovery_duration, **kwargs)

    def get_recovery_duration_ratio(
        self,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> MappedArray:
        """See `vectorbtpro.generic.nb.records.dd_recovery_duration_ratio_nb`.

        Computed from the fields `start_idx`, `valley_idx`, and `end_idx`.

        A value higher than 1 means the recovery was slower than the decline.

        Takes into account both recovered and active drawdowns."""
        func = jit_reg.resolve_option(nb.dd_recovery_duration_ratio_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        recovery_duration_ratio = func(
            self.get_field_arr("start_idx"),
            self.get_field_arr("valley_idx"),
            self.get_field_arr("end_idx"),
        )
        return self.map_array(recovery_duration_ratio, **kwargs)

    # ############# Status: Active ############# #

    def get_active_drawdown(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.MaybeSeries:
        """Get drawdown of the last active drawdown only.

        Computed as `(end_val - start_val) / start_val` of the last active record.

        Does not support grouping: raises `ValueError` if the data is grouped."""
        if self.wrapper.grouper.is_grouped(group_by=group_by):
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(dict(name_or_index="active_drawdown"), wrap_kwargs)
        active = self.status_active
        # `nth(-1)` selects the last active record
        curr_end_val = active.end_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
        curr_start_val = active.start_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
        curr_drawdown = (curr_end_val - curr_start_val) / curr_start_val
        return self.wrapper.wrap_reduced(curr_drawdown, group_by=group_by, **wrap_kwargs)

    def get_active_duration(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get duration of the last active drawdown only.

        `wrap_kwargs` is merged over `dict(to_timedelta=True, name_or_index="active_duration")`.

        Does not support grouping: raises `ValueError` if the data is grouped."""
        if self.wrapper.grouper.is_grouped(group_by=group_by):
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="active_duration"), wrap_kwargs)
        return self.status_active.duration.nth(
            -1,
            jitted=jitted,
            chunked=chunked,
            group_by=group_by,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )

    def get_active_recovery(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.MaybeSeries:
        """Get recovery of the last active drawdown only.

        Computed as `(end_val - valley_val) / (start_val - valley_val)` of the last active record,
        that is, the recovered part of the decline.

        Does not support grouping: raises `ValueError` if the data is grouped."""
        if self.wrapper.grouper.is_grouped(group_by=group_by):
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(dict(name_or_index="active_recovery"), wrap_kwargs)
        active = self.status_active
        # `nth(-1)` selects the last active record
        curr_start_val = active.start_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
        curr_end_val = active.end_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
        curr_valley_val = active.valley_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
        curr_recovery = (curr_end_val - curr_valley_val) / (curr_start_val - curr_valley_val)
        return self.wrapper.wrap_reduced(curr_recovery, group_by=group_by, **wrap_kwargs)

    def get_active_recovery_return(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get recovery return of the last active drawdown only.

        `wrap_kwargs` is merged over `dict(name_or_index="active_recovery_return")`.

        Does not support grouping: raises `ValueError` if the data is grouped."""
        if self.wrapper.grouper.is_grouped(group_by=group_by):
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(dict(name_or_index="active_recovery_return"), wrap_kwargs)
        return self.status_active.recovery_return.nth(
            -1,
            group_by=group_by,
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )

    def get_active_recovery_duration(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get recovery duration of the last active drawdown only.

        `wrap_kwargs` is merged over `dict(to_timedelta=True, name_or_index="active_recovery_duration")`.

        Does not support grouping: raises `ValueError` if the data is grouped."""
        if self.wrapper.grouper.is_grouped(group_by=group_by):
            raise ValueError("Grouping is not supported by this method")
        wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="active_recovery_duration"), wrap_kwargs)
        return self.status_active.recovery_duration.nth(
            -1,
            group_by=group_by,
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
            **kwargs,
        )

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Drawdowns.stats`.

Merges `vectorbtpro.generic.ranges.Ranges.stats_defaults` and `stats` from `vectorbtpro._settings.drawdowns`."""
        from vectorbtpro._settings import settings

        drawdowns_stats_cfg = settings["drawdowns"]["stats"]

        # The parent property is evaluated on this instance; the drawdown settings are merged over it
        return merge_dicts(Ranges.stats_defaults.__get__(self), drawdowns_stats_cfg)

    # Metrics available to `Drawdowns.stats`, keyed by metric name.
    # String `calc_func` values are attribute paths (e.g. "status_recovered.count").
    # NOTE(review): paths are presumably resolved on the instance by the stats builder - confirm there.
    _metrics: tp.ClassVar[Config] = HybridConfig(
        dict(
            start_index=dict(
                title="Start Index",
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags="wrapper",
            ),
            end_index=dict(
                title="End Index",
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags="wrapper",
            ),
            total_duration=dict(
                title="Total Duration",
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags="wrapper",
            ),
            coverage=dict(
                title="Coverage [%]",
                calc_func="coverage",
                # Scale to percent
                post_calc_func=lambda self, out, settings: out * 100,
                tags=["ranges", "duration"],
            ),
            total_records=dict(title="Total Records", calc_func="count", tags="records"),
            total_recovered=dict(
                title="Total Recovered Drawdowns",
                calc_func="status_recovered.count",
                tags="drawdowns",
            ),
            total_active=dict(title="Total Active Drawdowns", calc_func="status_active.count", tags="drawdowns"),
            # The `active_*` metrics are flagged with `check_is_not_grouped`
            active_dd=dict(
                title="Active Drawdown [%]",
                calc_func="active_drawdown",
                # Flip the sign and scale to percent
                post_calc_func=lambda self, out, settings: -out * 100,
                check_is_not_grouped=True,
                tags=["drawdowns", "active"],
            ),
            active_duration=dict(
                title="Active Duration",
                calc_func="active_duration",
                fill_wrap_kwargs=True,
                check_is_not_grouped=True,
                tags=["drawdowns", "active", "duration"],
            ),
            active_recovery=dict(
                title="Active Recovery [%]",
                calc_func="active_recovery",
                post_calc_func=lambda self, out, settings: out * 100,
                check_is_not_grouped=True,
                tags=["drawdowns", "active"],
            ),
            active_recovery_return=dict(
                title="Active Recovery Return [%]",
                calc_func="active_recovery_return",
                post_calc_func=lambda self, out, settings: out * 100,
                check_is_not_grouped=True,
                tags=["drawdowns", "active"],
            ),
            active_recovery_duration=dict(
                title="Active Recovery Duration",
                calc_func="active_recovery_duration",
                fill_wrap_kwargs=True,
                check_is_not_grouped=True,
                tags=["drawdowns", "active", "duration"],
            ),
            # The four metrics below switch between all drawdowns and recovered drawdowns only,
            # depending on the `incl_active` setting evaluated by `RepEval`
            max_dd=dict(
                title="Max Drawdown [%]",
                calc_func=RepEval("'max_drawdown' if incl_active else 'status_recovered.get_max_drawdown'"),
                # Flip the sign and scale to percent
                post_calc_func=lambda self, out, settings: -out * 100,
                tags=RepEval("['drawdowns'] if incl_active else ['drawdowns', 'recovered']"),
            ),
            avg_dd=dict(
                title="Avg Drawdown [%]",
                calc_func=RepEval("'avg_drawdown' if incl_active else 'status_recovered.get_avg_drawdown'"),
                post_calc_func=lambda self, out, settings: -out * 100,
                tags=RepEval("['drawdowns'] if incl_active else ['drawdowns', 'recovered']"),
            ),
            max_dd_duration=dict(
                title="Max Drawdown Duration",
                calc_func=RepEval("'max_duration' if incl_active else 'status_recovered.get_max_duration'"),
                fill_wrap_kwargs=True,
                tags=RepEval("['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"),
            ),
            avg_dd_duration=dict(
                title="Avg Drawdown Duration",
                calc_func=RepEval("'avg_duration' if incl_active else 'status_recovered.get_avg_duration'"),
                fill_wrap_kwargs=True,
                tags=RepEval("['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"),
            ),
            # The remaining metrics always consider recovered drawdowns only
            max_return=dict(
                title="Max Recovery Return [%]",
                calc_func="status_recovered.recovery_return.max",
                post_calc_func=lambda self, out, settings: out * 100,
                tags=["drawdowns", "recovered"],
            ),
            avg_return=dict(
                title="Avg Recovery Return [%]",
                calc_func="status_recovered.recovery_return.mean",
                post_calc_func=lambda self, out, settings: out * 100,
                tags=["drawdowns", "recovered"],
            ),
            max_recovery_duration=dict(
                title="Max Recovery Duration",
                calc_func="status_recovered.recovery_duration.max",
                apply_to_timedelta=True,
                tags=["drawdowns", "recovered", "duration"],
            ),
            avg_recovery_duration=dict(
                title="Avg Recovery Duration",
                calc_func="status_recovered.recovery_duration.mean",
                apply_to_timedelta=True,
                tags=["drawdowns", "recovered", "duration"],
            ),
            recovery_duration_ratio=dict(
                title="Avg Recovery Duration Ratio",
                calc_func="status_recovered.recovery_duration_ratio.mean",
                tags=["drawdowns", "recovered"],
            ),
        )
    )

    @property
    def metrics(self) -> Config:
        return self._metrics

    # ############# Plotting ############# #

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        top_n: tp.Optional[int] = 5,
        plot_ohlc: bool = True,
        plot_close: bool = True,
        plot_markers: bool = True,
        plot_zones: bool = True,
        ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
        ohlc_trace_kwargs: tp.KwargsLike = None,
        close_trace_kwargs: tp.KwargsLike = None,
        peak_trace_kwargs: tp.KwargsLike = None,
        valley_trace_kwargs: tp.KwargsLike = None,
        recovery_trace_kwargs: tp.KwargsLike = None,
        active_trace_kwargs: tp.KwargsLike = None,
        decline_shape_kwargs: tp.KwargsLike = None,
        recovery_shape_kwargs: tp.KwargsLike = None,
        active_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        xref: str = "x",
        yref: str = "y",
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot drawdowns.

        Args:
            column (str): Name of the column to plot.
            top_n (int): Filter top N drawdown records by maximum drawdown.
            plot_ohlc (bool): Whether to plot OHLC.
            plot_close (bool): Whether to plot close.
            plot_markers (bool): Whether to plot markers.
            plot_zones (bool): Whether to plot zones.
            ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

                Pass None to use the default.
            ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Drawdowns.close`.
            peak_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for peak values.
            valley_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for valley values.
            recovery_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for recovery values.
active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for active recovery values. decline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for decline zones. recovery_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for recovery zones. active_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for active recovery zones. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. xref (str): X coordinate axis. yref (str): Y coordinate axis. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> index = pd.date_range("2020", periods=8) >>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], index=index) >>> vbt.Drawdowns.from_price(price, wrapper_kwargs=dict(freq='1 day')).plot().show() ``` ![](/assets/images/api/drawdowns_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/drawdowns_plot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") import plotly.graph_objects as go from vectorbtpro.utils.figure import make_figure, get_domain from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=False) if top_n is not None: # Drawdowns is negative, thus top_n becomes bottom_n self_col = self_col.apply_mask(self_col.drawdown.bottom_n_mask(top_n)) if ohlc_trace_kwargs is None: ohlc_trace_kwargs = {} if close_trace_kwargs is None: close_trace_kwargs = {} close_trace_kwargs = merge_dicts( dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"), close_trace_kwargs, ) if peak_trace_kwargs is None: peak_trace_kwargs = {} if valley_trace_kwargs is None: valley_trace_kwargs = {} if recovery_trace_kwargs is None: recovery_trace_kwargs = {} if active_trace_kwargs is None: 
active_trace_kwargs = {} if decline_shape_kwargs is None: decline_shape_kwargs = {} if recovery_shape_kwargs is None: recovery_shape_kwargs = {} if active_shape_kwargs is None: active_shape_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) y_domain = get_domain(yref, fig) plotting_ohlc = False if ( plot_ohlc and self_col._open is not None and self_col._high is not None and self_col._low is not None and self_col._close is not None ): plotting_ohlc = True ohlc_df = pd.DataFrame( { "open": self_col.open, "high": self_col.high, "low": self_col.low, "close": self_col.close, } ) if "opacity" not in ohlc_trace_kwargs: ohlc_trace_kwargs["opacity"] = 0.5 fig = ohlc_df.vbt.ohlcv.plot( ohlc_type=ohlc_type, plot_volume=False, ohlc_trace_kwargs=ohlc_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) elif plot_close and self_col._close is not None: fig = self_col.close.vbt.lineplot( trace_kwargs=close_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if self_col.count() > 0: # Extract information id_ = self_col.get_field_arr("id") start_idx = self_col.get_map_field_to_index("start_idx") if not plotting_ohlc and self_col._close is not None: start_val = self_col.close.loc[start_idx] else: start_val = self_col.get_field_arr("start_val") valley_idx = self_col.get_map_field_to_index("valley_idx") if not plotting_ohlc and self_col._close is not None: valley_val = self_col.close.loc[valley_idx] else: valley_val = self_col.get_field_arr("valley_val") end_idx = self_col.get_map_field_to_index("end_idx") if not plotting_ohlc and self_col._close is not None: end_val = self_col.close.loc[end_idx] else: end_val = self_col.get_field_arr("end_val") drawdown = self_col.drawdown.values recovery_return = self_col.recovery_return.values status = self_col.get_field_arr("status") decline_duration = to_1d_array( self_col.wrapper.arr_to_timedelta( self_col.decline_duration.values, to_pd=True, 
silence_warnings=True ).astype(str) ) recovery_duration = to_1d_array( self_col.wrapper.arr_to_timedelta( self_col.recovery_duration.values, to_pd=True, silence_warnings=True ).astype(str) ) duration = to_1d_array( self_col.wrapper.arr_to_timedelta(self_col.duration.values, to_pd=True, silence_warnings=True).astype( str ) ) # Peak and recovery at same time -> recovery wins peak_mask = (start_val != np.roll(end_val, 1)) | (start_idx != np.roll(end_idx, 1)) if peak_mask.any(): if plot_markers: # Plot peak markers peak_customdata, peak_hovertemplate = self_col.prepare_customdata( incl_fields=["id", "start_idx", "start_val"], mask=peak_mask ) _peak_trace_kwargs = merge_dicts( dict( x=start_idx[peak_mask], y=start_val[peak_mask], mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["blue"], size=7, line=dict( width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["blue"]) ), ), name="Peak", customdata=peak_customdata, hovertemplate=peak_hovertemplate, ), peak_trace_kwargs, ) peak_scatter = go.Scatter(**_peak_trace_kwargs) fig.add_trace(peak_scatter, **add_trace_kwargs) recovered_mask = status == DrawdownStatus.Recovered if recovered_mask.any(): if plot_markers: # Plot valley markers valley_customdata, valley_hovertemplate = self_col.prepare_customdata( incl_fields=["id", "valley_idx", "valley_val"], append_info=[ (drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"), (decline_duration, "Decline duration"), ], mask=recovered_mask, ) _valley_trace_kwargs = merge_dicts( dict( x=valley_idx[recovered_mask], y=valley_val[recovered_mask], mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["red"], size=7, line=dict( width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["red"]) ), ), name="Valley", customdata=valley_customdata, hovertemplate=valley_hovertemplate, ), valley_trace_kwargs, ) valley_scatter = go.Scatter(**_valley_trace_kwargs) 
fig.add_trace(valley_scatter, **add_trace_kwargs) if plot_markers: # Plot recovery markers recovery_customdata, recovery_hovertemplate = self_col.prepare_customdata( incl_fields=["id", "end_idx", "end_val"], append_info=[ (drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"), (duration, "Drawdown duration"), (recovery_return, "Recovery return", "$title: %{customdata[$index]:,%}"), (recovery_duration, "Recovery duration"), ], mask=recovered_mask, ) _recovery_trace_kwargs = merge_dicts( dict( x=end_idx[recovered_mask], y=end_val[recovered_mask], mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["green"], size=7, line=dict( width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"]) ), ), name="Recovery/Peak", customdata=recovery_customdata, hovertemplate=recovery_hovertemplate, ), recovery_trace_kwargs, ) recovery_scatter = go.Scatter(**_recovery_trace_kwargs) fig.add_trace(recovery_scatter, **add_trace_kwargs) active_mask = status == DrawdownStatus.Active if active_mask.any(): if plot_markers: # Plot active markers active_customdata, active_hovertemplate = self_col.prepare_customdata( incl_fields=["id"], append_info=[ (drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"), (duration, "Drawdown duration"), ], mask=active_mask, ) _active_trace_kwargs = merge_dicts( dict( x=end_idx[active_mask], y=end_val[active_mask], mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["orange"], size=7, line=dict( width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["orange"]) ), ), name="Active", customdata=active_customdata, hovertemplate=active_hovertemplate, ), active_trace_kwargs, ) active_scatter = go.Scatter(**_active_trace_kwargs) fig.add_trace(active_scatter, **add_trace_kwargs) if plot_zones: # Plot drawdown zones self_col.status_recovered.plot_shapes( plot_ohlc=False, plot_close=False, shape_kwargs=merge_dicts( dict( x0=RepFunc(lambda i: 
start_idx[recovered_mask][i]), x1=RepFunc(lambda i: valley_idx[recovered_mask][i]), fillcolor=plotting_cfg["contrast_color_schema"]["red"], ), decline_shape_kwargs, ), add_trace_kwargs=add_trace_kwargs, xref=xref, yref=yref, fig=fig, ) # Plot recovery zones self_col.status_recovered.plot_shapes( plot_ohlc=False, plot_close=False, shape_kwargs=merge_dicts( dict( x0=RepFunc(lambda i: valley_idx[recovered_mask][i]), x1=RepFunc(lambda i: end_idx[recovered_mask][i]), fillcolor=plotting_cfg["contrast_color_schema"]["green"], ), recovery_shape_kwargs, ), add_trace_kwargs=add_trace_kwargs, xref=xref, yref=yref, fig=fig, ) # Plot active drawdown zones self_col.status_active.plot_shapes( plot_ohlc=False, plot_close=False, shape_kwargs=merge_dicts( dict( x0=RepFunc(lambda i: start_idx[active_mask][i]), x1=RepFunc(lambda i: end_idx[active_mask][i]), fillcolor=plotting_cfg["contrast_color_schema"]["orange"], ), active_shape_kwargs, ), add_trace_kwargs=add_trace_kwargs, xref=xref, yref=yref, fig=fig, ) return fig @property def plots_defaults(self) -> tp.Kwargs: """Defaults for `Drawdowns.plots`. Merges `vectorbtpro.generic.ranges.Ranges.plots_defaults` and `plots` from `vectorbtpro._settings.drawdowns`.""" from vectorbtpro._settings import settings drawdowns_plots_cfg = settings["drawdowns"]["plots"] return merge_dicts(Ranges.plots_defaults.__get__(self), drawdowns_plots_cfg) _subplots: tp.ClassVar[Config] = HybridConfig( dict( plot=dict( title="Drawdowns", check_is_not_grouped=True, plot_func="plot", tags="drawdowns", ) ), ) @property def subplots(self) -> Config: return self._subplots Drawdowns.override_field_config_doc(__pdoc__) Drawdowns.override_metrics_doc(__pdoc__) Drawdowns.override_subplots_doc(__pdoc__) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Named tuples and enumerated types for generic data.

Defines enums and other schemas for `vectorbtpro.generic`."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.utils.formatting import prettify

__pdoc__all__ = __all__ = [
    "BarZone",
    "WType",
    "RangeStatus",
    "InterpMode",
    "RescaleMode",
    "ErrorType",
    "DistanceMeasure",
    "OverlapMode",
    "DrawdownStatus",
    "range_dt",
    "pattern_range_dt",
    "drawdown_dt",
    "RollSumAIS",
    "RollSumAOS",
    "RollProdAIS",
    "RollProdAOS",
    "RollMeanAIS",
    "RollMeanAOS",
    "RollStdAIS",
    "RollStdAOS",
    "RollZScoreAIS",
    "RollZScoreAOS",
    "WMMeanAIS",
    "WMMeanAOS",
    "EWMMeanAIS",
    "EWMMeanAOS",
    "EWMStdAIS",
    "EWMStdAOS",
    "VidyaAIS",
    "VidyaAOS",
    "RollCovAIS",
    "RollCovAOS",
    "RollCorrAIS",
    "RollCorrAOS",
    "RollOLSAIS",
    "RollOLSAOS",
]

# Documentation text keyed by exported name; filled in next to each definition below
# (presumably consumed by the API documentation generator)
__pdoc__ = {}


# ############# Enums ############# #

# Each enum is a named tuple class whose fields default to integer codes.
# Only the instance created right after the class (e.g. `BarZone`) is listed in `__all__`.


class BarZoneT(tp.NamedTuple):
    Open: int = 0
    Middle: int = 1
    Close: int = 2


BarZone = BarZoneT()
"""_"""

__pdoc__["BarZone"] = f"""Bar zone.

```python
{prettify(BarZone)}
```
"""


class WTypeT(tp.NamedTuple):
    Simple: int = 0
    Weighted: int = 1
    Exp: int = 2
    Wilder: int = 3
    Vidya: int = 4


WType = WTypeT()
"""_"""

__pdoc__["WType"] = f"""Rolling window type.

```python
{prettify(WType)}
```
"""


class RangeStatusT(tp.NamedTuple):
    Open: int = 0
    Closed: int = 1


RangeStatus = RangeStatusT()
"""_"""

__pdoc__["RangeStatus"] = f"""Range status.

```python
{prettify(RangeStatus)}
```
"""


class InterpModeT(tp.NamedTuple):
    Linear: int = 0
    Nearest: int = 1
    Discrete: int = 2
    Mixed: int = 3


InterpMode = InterpModeT()
"""_"""

# The attribute list below must use the field names declared on `InterpModeT`
__pdoc__["InterpMode"] = f"""Interpolation mode.

```python
{prettify(InterpMode)}
```

Attributes:
    Linear: Linear interpolation. For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.5, 2.0, 2.5, 3.0]`
    Nearest: Nearest-neighbor interpolation. For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.0, 2.0, 3.0, 3.0]`
    Discrete: Discrete interpolation. For example: `[1.0, 2.0, 3.0]` -> `[1.0, np.nan, 2.0, np.nan, 3.0]`
    Mixed: Mixed interpolation. For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.5, 2.0, 2.5, 3.0]`
"""


class RescaleModeT(tp.NamedTuple):
    MinMax: int = 0
    Rebase: int = 1
    Disable: int = 2


RescaleMode = RescaleModeT()
"""_"""

__pdoc__["RescaleMode"] = f"""Rescaling mode.

```python
{prettify(RescaleMode)}
```

Attributes:
    MinMax: Array is rescaled from its min-max range to the min-max range of another array.

        For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> `[12.0, 11.0, 10.0]`

        Use this to search for patterns irrespective of their vertical scale.
    Rebase: Array is rebased to the first value in another array.

        For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> `[10.0, 6.6, 3.3]`

        Use this to search for percentage changes.
    Disable: Disable any rescaling.

        For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> `[3.0, 2.0, 1.0]`

        Use this to search for particular numbers.
"""


class ErrorTypeT(tp.NamedTuple):
    Absolute: int = 0
    Relative: int = 1


ErrorType = ErrorTypeT()
"""_"""

__pdoc__["ErrorType"] = f"""Error type.

```python
{prettify(ErrorType)}
```

Attributes:
    Absolute: Absolute error, that is, `x1 - x0`.
    Relative: Relative error, that is, `(x1 - x0) / x0`.
"""


class DistanceMeasureT(tp.NamedTuple):
    MAE: int = 0
    MSE: int = 1
    RMSE: int = 2


DistanceMeasure = DistanceMeasureT()
"""_"""

__pdoc__["DistanceMeasure"] = f"""Distance measure.

```python
{prettify(DistanceMeasure)}
```

Attributes:
    MAE: Mean absolute error.
    MSE: Mean squared error.
    RMSE: Root mean squared error.
"""


class OverlapModeT(tp.NamedTuple):
    AllowAll: int = -2
    Allow: int = -1
    Disallow: int = 0


OverlapMode = OverlapModeT()
"""_"""

__pdoc__["OverlapMode"] = f"""Overlapping mode.

```python
{prettify(OverlapMode)}
```

Attributes:
    AllowAll: Allow any overlapping ranges, even if they start at the same row.
    Allow: Allow overlapping ranges, but only if they do not start at the same row.
    Disallow: Disallow any overlapping ranges.

Any other positive number will check whether the intersection of each two consecutive ranges is bigger
than that number of rows, and if so, the range with the highest similarity will be selected.
"""


class DrawdownStatusT(tp.NamedTuple):
    Active: int = 0
    Recovered: int = 1


DrawdownStatus = DrawdownStatusT()
"""_"""

__pdoc__["DrawdownStatus"] = f"""Drawdown status.

```python
{prettify(DrawdownStatus)}
```
"""


# ############# Records ############# #

# Structured dtypes; `int_` and `float_` come from the star import of `vectorbtpro._dtypes`

range_dt = np.dtype(
    [
        ("id", int_),
        ("col", int_),
        ("start_idx", int_),
        ("end_idx", int_),
        ("status", int_),
    ],
    align=True,
)
"""_"""

__pdoc__["range_dt"] = f"""`np.dtype` of range records.

```python
{prettify(range_dt)}
```
"""

pattern_range_dt = np.dtype(
    [
        ("id", int_),
        ("col", int_),
        ("start_idx", int_),
        ("end_idx", int_),
        ("status", int_),
        ("similarity", float_),
    ],
    align=True,
)
"""_"""

__pdoc__["pattern_range_dt"] = f"""`np.dtype` of pattern range records.

```python
{prettify(pattern_range_dt)}
```
"""

drawdown_dt = np.dtype(
    [
        ("id", int_),
        ("col", int_),
        ("start_idx", int_),
        ("valley_idx", int_),
        ("end_idx", int_),
        ("start_val", float_),
        ("valley_val", float_),
        ("end_val", float_),
        ("status", int_),
    ],
    align=True,
)
"""_"""

__pdoc__["drawdown_dt"] = f"""`np.dtype` of drawdown records.

```python
{prettify(drawdown_dt)}
```
"""


# ############# States ############# #

# Each `*AIS` tuple is documented as the input state and each `*AOS` tuple as the output
# state of the named accumulator function in `vectorbtpro.generic.nb.rolling`


class RollSumAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumsum: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["RollSumAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_sum_acc_nb`."""


class RollSumAOS(tp.NamedTuple):
    cumsum: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollSumAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_sum_acc_nb`."""


class RollProdAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumprod: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["RollProdAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_prod_acc_nb`."""


class RollProdAOS(tp.NamedTuple):
    cumprod: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollProdAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_prod_acc_nb`."""


class RollMeanAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumsum: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["RollMeanAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_mean_acc_nb`."""


class RollMeanAOS(tp.NamedTuple):
    cumsum: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollMeanAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_mean_acc_nb`."""


class RollStdAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumsum: float
    cumsum_sq: float
    nancnt: int
    window: int
    minp: tp.Optional[int]
    ddof: int


__pdoc__["RollStdAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_std_acc_nb`."""


class RollStdAOS(tp.NamedTuple):
    cumsum: float
    cumsum_sq: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollStdAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_std_acc_nb`."""


class RollZScoreAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumsum: float
    cumsum_sq: float
    nancnt: int
    window: int
    minp: tp.Optional[int]
    ddof: int


__pdoc__["RollZScoreAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_zscore_acc_nb`."""


class RollZScoreAOS(tp.NamedTuple):
    cumsum: float
    cumsum_sq: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollZScoreAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_zscore_acc_nb`."""


class WMMeanAIS(tp.NamedTuple):
    i: int
    value: float
    pre_window_value: float
    cumsum: float
    wcumsum: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["WMMeanAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.wm_mean_acc_nb`."""


class WMMeanAOS(tp.NamedTuple):
    cumsum: float
    wcumsum: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["WMMeanAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.wm_mean_acc_nb`."""


class EWMMeanAIS(tp.NamedTuple):
    i: int
    value: float
    old_wt: float
    weighted_avg: float
    nobs: int
    alpha: float
    minp: tp.Optional[int]
    adjust: bool


__pdoc__["EWMMeanAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.ewm_mean_acc_nb`.

To get `alpha`, use one of the following:

* `vectorbtpro.generic.nb.rolling.alpha_from_com_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_span_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_halflife_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_wilder_nb`"""


class EWMMeanAOS(tp.NamedTuple):
    old_wt: float
    weighted_avg: float
    nobs: int
    value: float


__pdoc__["EWMMeanAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.ewm_mean_acc_nb`."""


class EWMStdAIS(tp.NamedTuple):
    i: int
    value: float
    mean_x: float
    mean_y: float
    cov: float
    sum_wt: float
    sum_wt2: float
    old_wt: float
    nobs: int
    alpha: float
    minp: tp.Optional[int]
    adjust: bool


__pdoc__["EWMStdAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.ewm_std_acc_nb`.

For tips on `alpha`, see `EWMMeanAIS`."""


class EWMStdAOS(tp.NamedTuple):
    mean_x: float
    mean_y: float
    cov: float
    sum_wt: float
    sum_wt2: float
    old_wt: float
    nobs: int
    value: float


__pdoc__["EWMStdAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.ewm_std_acc_nb`."""


class VidyaAIS(tp.NamedTuple):
    i: int
    prev_value: float
    value: float
    pre_window_prev_value: float
    pre_window_value: float
    pos_cumsum: float
    neg_cumsum: float
    prev_vidya: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["VidyaAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.vidya_acc_nb`."""


class VidyaAOS(tp.NamedTuple):
    pos_cumsum: float
    neg_cumsum: float
    nancnt: int
    window_len: int
    cmo: float
    vidya: float


__pdoc__["VidyaAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.vidya_acc_nb`."""


class RollCovAIS(tp.NamedTuple):
    i: int
    value1: float
    value2: float
    pre_window_value1: float
    pre_window_value2: float
    cumsum1: float
    cumsum2: float
    cumsum_prod: float
    nancnt: int
    window: int
    minp: tp.Optional[int]
    ddof: int


__pdoc__["RollCovAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_cov_acc_nb`."""


class RollCovAOS(tp.NamedTuple):
    cumsum1: float
    cumsum2: float
    cumsum_prod: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollCovAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_cov_acc_nb`."""


class RollCorrAIS(tp.NamedTuple):
    i: int
    value1: float
    value2: float
    pre_window_value1: float
    pre_window_value2: float
    cumsum1: float
    cumsum2: float
    cumsum_sq1: float
    cumsum_sq2: float
    cumsum_prod: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["RollCorrAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_corr_acc_nb`."""


class RollCorrAOS(tp.NamedTuple):
    cumsum1: float
    cumsum2: float
    cumsum_sq1: float
    cumsum_sq2: float
    cumsum_prod: float
    nancnt: int
    window_len: int
    value: float


__pdoc__["RollCorrAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_corr_acc_nb`."""


class RollOLSAIS(tp.NamedTuple):
    i: int
    value1: float
    value2: float
    pre_window_value1: float
    pre_window_value2: float
    validcnt: int
    cumsum1: float
    cumsum2: float
    cumsum_sq1: float
    cumsum_prod: float
    nancnt: int
    window: int
    minp: tp.Optional[int]


__pdoc__["RollOLSAIS"] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_ols_acc_nb`."""


class RollOLSAOS(tp.NamedTuple):
    validcnt: int
    cumsum1: float
    cumsum2: float
    cumsum_sq1: float
    cumsum_prod: float
    nancnt: int
    window_len: int
    slope_value: float
    intercept_value: float


__pdoc__["RollOLSAOS"] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_ols_acc_nb`."""

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Mixin for building plots out of subplots."""

import inspect
import string
from collections import Counter

from vectorbtpro import _typing as tp
from vectorbtpro.base.indexing import ParamLoc
from vectorbtpro.base.wrapping import Wrapping
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import get_dict_attr, AttrResolverMixin
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names, get_forward_args
from vectorbtpro.utils.tagging import match_tags
from vectorbtpro.utils.template import substitute_templates, CustomTemplate
from vectorbtpro.utils.warnings_ import warn

__all__ = []


class MetaPlotsBuilderMixin(type):
    """Metaclass of `PlotsBuilderMixin` that exposes `subplots` on the class itself."""

    @property
    def subplots(cls) -> Config:
        """Class-level subplots config used by `PlotsBuilderMixin.plots`."""
        return cls._subplots


class PlotsBuilderMixin(Base, metaclass=MetaPlotsBuilderMixin):
    """Mixin providing `PlotsBuilderMixin.plots`.

    The host class must also derive from `vectorbtpro.base.wrapping.Wrapping`."""

    _writeable_attrs: tp.WriteableAttrs = {"_subplots"}

    def __init__(self) -> None:
        checks.assert_instance_of(self, Wrapping)

        # Give the instance its own copy of the class-level config
        cls = type(self)
        self._subplots = cls._subplots.copy()

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Default arguments of `PlotsBuilderMixin.plots`."""
        return {"settings": {"freq": self.wrapper.freq}}

    def resolve_plots_setting(
        self,
        value: tp.Optional[tp.Any],
        key: str,
        merge: bool = False,
    ) -> tp.Any:
        """Resolve the setting `key` of `PlotsBuilderMixin.plots`.

        With `merge`, the global `plots_builder` setting, the entry from `plots_defaults`,
        and `value` are merged in that order. Otherwise, `value` is returned unless it is None,
        in which case the entry from `plots_defaults` (or else the global setting) is returned."""
        from vectorbtpro._settings import settings as _settings

        global_cfg = _settings["plots_builder"]

        if not merge:
            if value is None:
                return self.plots_defaults.get(key, global_cfg[key])
            return value

        layers = (
            global_cfg[key],
            self.plots_defaults.get(key, {}),
            value,
        )
        return merge_dicts(*layers)

    _subplots: tp.ClassVar[Config] = HybridConfig({})

    @property
    def subplots(self) -> Config:
        """Subplots supported by `${cls_name}`.

        ```python
        ${subplots}
        ```

        Returns `${cls_name}._subplots`, which gets (hybrid-) copied upon creation of each instance.
        Thus, changing this config won't affect the class.

        To change subplots, you can either change the config in-place, override this property,
        or overwrite the instance variable `${cls_name}._subplots`."""
        return self._subplots
* `tags`, `check_{filter}`, `inv_check_{filter}`, `resolve_plot_func`, `pass_{arg}`, `resolve_path_{arg}`, `resolve_{arg}` and `template_context`: The same as in `vectorbtpro.generic.stats_builder.StatsBuilderMixin` for `calc_func`. * Any other keyword argument that overrides the settings or is passed directly to `plot_func`. If `resolve_plot_func` is True, the plotting function may "request" any of the following arguments by accepting them or if `pass_{arg}` was found in the settings dict: * Each of `vectorbtpro.utils.attr_.AttrResolverMixin.self_aliases`: original object (ungrouped, with no column selected) * `group_by`: won't be passed if it was used in resolving the first attribute of `plot_func` specified as a path, use `pass_group_by=True` to pass anyway * `column` * `subplot_name` * `trace_names`: list with the subplot name, can't be used in templates * `add_trace_kwargs`: dict with subplot row and column index * `xref` * `yref` * `xaxis` * `yaxis` * `x_domain` * `y_domain` * `fig` * `silence_warnings` * Any argument from `settings` * Any attribute of this object if it meant to be resolved (see `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`) !!! note Layout-related resolution arguments such as `add_trace_kwargs` are unavailable before filtering and thus cannot be used in any templates but can still be overridden. Pass `subplots='all'` to plot all supported subplots. tags (str or iterable): See `tags` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. column (str): See `column` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. group_by (any): See `group_by` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. per_column (bool): See `per_column` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. split_columns (bool): See `split_columns` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. silence_warnings (bool): See `silence_warnings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. 
template_context (mapping): See `template_context` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. Applied on `settings`, `make_subplots_kwargs`, and `layout_kwargs`, and then on each subplot settings. filters (dict): See `filters` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. settings (dict): See `settings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. subplot_settings (dict): See `metric_settings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. show_titles (bool): Whether to show the title of each subplot. show_legend (bool): Whether to show legend. If None and plotting per column, becomes False, otherwise True. show_column_label (bool): Whether to show the column label next to each legend label. If None and plotting per column, becomes True, otherwise False. hide_id_labels (bool): Whether to hide identical legend labels. Two labels are identical if their name, marker style and line style match. group_id_labels (bool): Whether to group identical legend labels. make_subplots_kwargs (dict): Keyword arguments passed to `plotly.subplots.make_subplots`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments used to update the layout of the figure. !!! note `PlotsBuilderMixin` and `vectorbtpro.generic.stats_builder.StatsBuilderMixin` are very similar. Some artifacts follow the same concept, just named differently: * `plots_defaults` vs `stats_defaults` * `subplots` vs `metrics` * `subplot_settings` vs `metric_settings` See further notes under `vectorbtpro.generic.stats_builder.StatsBuilderMixin`. 
""" # Plot per column if column is None: if per_column is None: per_column = self.resolve_plots_setting(per_column, "per_column") if per_column: columns = self.get_item_keys(group_by=group_by) if len(columns) > 1: if split_columns is None: split_columns = self.resolve_plots_setting(split_columns, "split_columns") if show_legend is None: show_legend = self.resolve_plots_setting(show_legend, "show_legend") if show_legend is None: show_legend = False if show_column_label is None: show_column_label = self.resolve_plots_setting(show_column_label, "show_column_label") if show_column_label is None: show_column_label = True fig = None if split_columns: for _, column_self in self.items(group_by=group_by, wrap=True): _args, _kwargs = get_forward_args(column_self.plots, locals()) fig = column_self.plots(*_args, **_kwargs) else: for column in columns: _args, _kwargs = get_forward_args(self.plots, locals()) fig = self.plots(*_args, **_kwargs) return fig from vectorbtpro.utils.figure import make_subplots, get_domain from vectorbtpro._settings import settings as _settings plotting_cfg = _settings["plotting"] # Resolve defaults silence_warnings = self.resolve_plots_setting(silence_warnings, "silence_warnings") show_titles = self.resolve_plots_setting(show_titles, "show_titles") show_legend = self.resolve_plots_setting(show_legend, "show_legend") if show_legend is None: show_legend = True show_column_label = self.resolve_plots_setting(show_column_label, "show_column_label") if show_column_label is None: show_column_label = False hide_id_labels = self.resolve_plots_setting(hide_id_labels, "hide_id_labels") group_id_labels = self.resolve_plots_setting(group_id_labels, "group_id_labels") template_context = self.resolve_plots_setting(template_context, "template_context", merge=True) filters = self.resolve_plots_setting(filters, "filters", merge=True) settings = self.resolve_plots_setting(settings, "settings", merge=True) subplot_settings = self.resolve_plots_setting(subplot_settings, 
"subplot_settings", merge=True) make_subplots_kwargs = self.resolve_plots_setting(make_subplots_kwargs, "make_subplots_kwargs", merge=True) layout_kwargs = self.resolve_plots_setting(layout_kwargs, "layout_kwargs", merge=True) # Replace templates globally (not used at subplot level) if len(template_context) > 0: sub_settings = substitute_templates( settings, context=template_context, eval_id="sub_settings", strict=False, ) sub_make_subplots_kwargs = substitute_templates( make_subplots_kwargs, context=template_context, eval_id="sub_make_subplots_kwargs", ) sub_layout_kwargs = substitute_templates( layout_kwargs, context=template_context, eval_id="sub_layout_kwargs", ) else: sub_settings = settings sub_make_subplots_kwargs = make_subplots_kwargs sub_layout_kwargs = layout_kwargs # Resolve self reself = self.resolve_self( cond_kwargs=sub_settings, impacts_caching=False, silence_warnings=silence_warnings, ) # Prepare subplots if subplots is None: subplots = reself.resolve_plots_setting(subplots, "subplots") if subplots == "all": subplots = reself.subplots if isinstance(subplots, dict): subplots = list(subplots.items()) if isinstance(subplots, (str, tuple)): subplots = [subplots] # Prepare tags if tags is None: tags = reself.resolve_plots_setting(tags, "tags") if isinstance(tags, str) and tags == "all": tags = None if isinstance(tags, (str, tuple)): tags = [tags] # Bring to the same shape new_subplots = [] for i, subplot in enumerate(subplots): if isinstance(subplot, str): subplot = (subplot, reself.subplots[subplot]) if not isinstance(subplot, tuple): raise TypeError(f"Subplot at index {i} must be either a string or a tuple") new_subplots.append(subplot) subplots = new_subplots # Expand subplots new_subplots = [] for i, (subplot_name, _subplot_settings) in enumerate(subplots): if isinstance(_subplot_settings, CustomTemplate): subplot_context = merge_dicts( template_context, {name: reself for name in reself.self_aliases}, dict( column=column, group_by=group_by, 
subplot_name=subplot_name, silence_warnings=silence_warnings, ), settings, ) subplot_context = substitute_templates( subplot_context, context=subplot_context, eval_id="subplot_context", ) _subplot_settings = _subplot_settings.substitute( context=subplot_context, strict=True, eval_id="subplot", ) if isinstance(_subplot_settings, list): for __subplot_settings in _subplot_settings: new_subplots.append((subplot_name, __subplot_settings)) else: new_subplots.append((subplot_name, _subplot_settings)) subplots = new_subplots # Handle duplicate names subplot_counts = Counter(list(map(lambda x: x[0], subplots))) subplot_i = {k: -1 for k in subplot_counts.keys()} subplots_dct = {} for i, (subplot_name, _subplot_settings) in enumerate(subplots): if subplot_counts[subplot_name] > 1: subplot_i[subplot_name] += 1 subplot_name = subplot_name + "_" + str(subplot_i[subplot_name]) subplots_dct[subplot_name] = _subplot_settings # Check subplot_settings missed_keys = set(subplot_settings.keys()).difference(set(subplots_dct.keys())) if len(missed_keys) > 0: raise ValueError(f"Keys {missed_keys} in subplot_settings could not be matched with any subplot") # Merge settings opt_arg_names_dct = {} custom_arg_names_dct = {} resolved_self_dct = {} context_dct = {} for subplot_name, _subplot_settings in list(subplots_dct.items()): opt_settings = merge_dicts( {name: reself for name in reself.self_aliases}, dict( column=column, group_by=group_by, subplot_name=subplot_name, trace_names=[subplot_name], silence_warnings=silence_warnings, ), settings, ) _subplot_settings = _subplot_settings.copy() passed_subplot_settings = subplot_settings.get(subplot_name, {}) merged_settings = merge_dicts(opt_settings, _subplot_settings, passed_subplot_settings) subplot_template_context = merged_settings.pop("template_context", {}) template_context_merged = merge_dicts(template_context, subplot_template_context) template_context_merged = substitute_templates( template_context_merged, context=merged_settings, 
eval_id="template_context_merged", ) context = merge_dicts(template_context_merged, merged_settings) # safe because we will use substitute_templates again once layout params are known merged_settings = substitute_templates( merged_settings, context=context, eval_id="merged_settings", ) # Filter by tag if tags is not None: in_tags = merged_settings.get("tags", None) if in_tags is None or not match_tags(tags, in_tags): subplots_dct.pop(subplot_name, None) continue custom_arg_names = set(_subplot_settings.keys()).union(set(passed_subplot_settings.keys())) opt_arg_names = set(opt_settings.keys()) custom_reself = reself.resolve_self( cond_kwargs=merged_settings, custom_arg_names=custom_arg_names, impacts_caching=True, silence_warnings=merged_settings["silence_warnings"], ) subplots_dct[subplot_name] = merged_settings custom_arg_names_dct[subplot_name] = custom_arg_names opt_arg_names_dct[subplot_name] = opt_arg_names resolved_self_dct[subplot_name] = custom_reself context_dct[subplot_name] = context # Filter subplots for subplot_name, _subplot_settings in list(subplots_dct.items()): custom_reself = resolved_self_dct[subplot_name] context = context_dct[subplot_name] _silence_warnings = _subplot_settings.get("silence_warnings") subplot_filters = set() for k in _subplot_settings.keys(): filter_name = None if k.startswith("check_"): filter_name = k[len("check_") :] elif k.startswith("inv_check_"): filter_name = k[len("inv_check_") :] if filter_name is not None: if filter_name not in filters: raise ValueError(f"Metric '{subplot_name}' requires filter '{filter_name}'") subplot_filters.add(filter_name) for filter_name in subplot_filters: filter_settings = filters[filter_name] _filter_settings = substitute_templates( filter_settings, context=context, eval_id="filter_settings", ) filter_func = _filter_settings["filter_func"] warning_message = _filter_settings.get("warning_message", None) inv_warning_message = _filter_settings.get("inv_warning_message", None) to_check = 
_subplot_settings.get("check_" + filter_name, False) inv_to_check = _subplot_settings.get("inv_check_" + filter_name, False) if to_check or inv_to_check: whether_true = filter_func(custom_reself, _subplot_settings) to_remove = (to_check and not whether_true) or (inv_to_check and whether_true) if to_remove: if to_check and warning_message is not None and not _silence_warnings: warn(warning_message) if inv_to_check and inv_warning_message is not None and not _silence_warnings: warn(inv_warning_message) subplots_dct.pop(subplot_name, None) custom_arg_names_dct.pop(subplot_name, None) opt_arg_names_dct.pop(subplot_name, None) resolved_self_dct.pop(subplot_name, None) context_dct.pop(subplot_name, None) break # Any subplots left? if len(subplots_dct) == 0: if not silence_warnings: warn("No subplots to plot") return None # Set up figure rows = sub_make_subplots_kwargs.pop("rows", len(subplots_dct)) cols = sub_make_subplots_kwargs.pop("cols", 1) specs = sub_make_subplots_kwargs.pop( "specs", [[{} for _ in range(cols)] for _ in range(rows)], ) row_col_tuples = [] for row, row_spec in enumerate(specs): for col, col_spec in enumerate(row_spec): if col_spec is not None: row_col_tuples.append((row + 1, col + 1)) shared_xaxes = sub_make_subplots_kwargs.pop("shared_xaxes", True) shared_yaxes = sub_make_subplots_kwargs.pop("shared_yaxes", False) default_height = plotting_cfg["layout"]["height"] default_width = plotting_cfg["layout"]["width"] + 50 min_space = 10 # space between subplots with no axis sharing max_title_spacing = 30 max_xaxis_spacing = 50 max_yaxis_spacing = 100 legend_height = 50 if show_titles: title_spacing = max_title_spacing else: title_spacing = 0 if not shared_xaxes and rows > 1: xaxis_spacing = max_xaxis_spacing else: xaxis_spacing = 0 if not shared_yaxes and cols > 1: yaxis_spacing = max_yaxis_spacing else: yaxis_spacing = 0 if "height" in sub_layout_kwargs: height = sub_layout_kwargs.pop("height") else: height = default_height + title_spacing if rows > 1: 
height *= rows height += min_space * rows - min_space height += legend_height - legend_height * rows if shared_xaxes: height += max_xaxis_spacing - max_xaxis_spacing * rows if "width" in sub_layout_kwargs: width = sub_layout_kwargs.pop("width") else: width = default_width if cols > 1: width *= cols width += min_space * cols - min_space if shared_yaxes: width += max_yaxis_spacing - max_yaxis_spacing * cols if height is not None: if "vertical_spacing" in sub_make_subplots_kwargs: vertical_spacing = sub_make_subplots_kwargs.pop("vertical_spacing") else: vertical_spacing = min_space + title_spacing + xaxis_spacing if vertical_spacing is not None and vertical_spacing > 1: vertical_spacing /= height legend_y = 1 + (min_space + title_spacing) / height else: vertical_spacing = sub_make_subplots_kwargs.pop("vertical_spacing", None) legend_y = 1.02 if width is not None: if "horizontal_spacing" in sub_make_subplots_kwargs: horizontal_spacing = sub_make_subplots_kwargs.pop("horizontal_spacing") else: horizontal_spacing = min_space + yaxis_spacing if horizontal_spacing is not None and horizontal_spacing > 1: horizontal_spacing /= width else: horizontal_spacing = sub_make_subplots_kwargs.pop("horizontal_spacing", None) if show_titles: _subplot_titles = [] for i in range(len(subplots_dct)): _subplot_titles.append("$title_" + str(i)) else: _subplot_titles = None if fig is None: fig = make_subplots( rows=rows, cols=cols, specs=specs, shared_xaxes=shared_xaxes, shared_yaxes=shared_yaxes, subplot_titles=_subplot_titles, vertical_spacing=vertical_spacing, horizontal_spacing=horizontal_spacing, **sub_make_subplots_kwargs, ) sub_layout_kwargs = merge_dicts( dict( showlegend=True, width=width, height=height, legend=dict( orientation="h", yanchor="bottom", y=legend_y, xanchor="right", x=1, traceorder="normal", ), ), sub_layout_kwargs, ) trace_start_idx = 0 else: trace_start_idx = len(fig.data) fig.update_layout(**sub_layout_kwargs) # Plot subplots arg_cache_dct = {} for i, (subplot_name, 
_subplot_settings) in enumerate(subplots_dct.items()): try: final_kwargs = _subplot_settings.copy() opt_arg_names = opt_arg_names_dct[subplot_name] custom_arg_names = custom_arg_names_dct[subplot_name] custom_reself = resolved_self_dct[subplot_name] context = context_dct[subplot_name] # Compute figure artifacts row, col = row_col_tuples[i] xref = "x" if i == 0 else "x" + str(i + 1) yref = "y" if i == 0 else "y" + str(i + 1) xaxis = "xaxis" + xref[1:] yaxis = "yaxis" + yref[1:] x_domain = get_domain(xref, fig) y_domain = get_domain(yref, fig) subplot_layout_kwargs = dict( add_trace_kwargs=dict(row=row, col=col), xref=xref, yref=yref, xaxis=xaxis, yaxis=yaxis, x_domain=x_domain, y_domain=y_domain, fig=fig, pass_fig=True, # force passing fig ) for k in subplot_layout_kwargs: opt_arg_names.add(k) if k in final_kwargs: custom_arg_names.add(k) final_kwargs = merge_dicts(subplot_layout_kwargs, final_kwargs) context = merge_dicts(subplot_layout_kwargs, context) final_kwargs = substitute_templates(final_kwargs, context=context, eval_id="final_kwargs") # Clean up keys for k, v in list(final_kwargs.items()): if k.startswith("check_") or k.startswith("inv_check_") or k in ("tags",): final_kwargs.pop(k, None) # Get subplot-specific values _column = final_kwargs.get("column") _group_by = final_kwargs.get("group_by") _silence_warnings = final_kwargs.get("silence_warnings") title = final_kwargs.pop("title", subplot_name) plot_func = final_kwargs.pop("plot_func", None) xaxis_kwargs = final_kwargs.pop("xaxis_kwargs", None) yaxis_kwargs = final_kwargs.pop("yaxis_kwargs", None) resolve_plot_func = final_kwargs.pop("resolve_plot_func", True) use_shortcuts = final_kwargs.pop("use_shortcuts", True) use_caching = final_kwargs.pop("use_caching", True) if plot_func is not None: # Resolve plot_func if resolve_plot_func: if not callable(plot_func): passed_kwargs_out = {} def _getattr_func( obj: tp.Any, attr: str, args: tp.ArgsLike = None, kwargs: tp.KwargsLike = None, call_attr: bool = True, 
_final_kwargs: tp.Kwargs = final_kwargs, _opt_arg_names: tp.Set[str] = opt_arg_names, _custom_arg_names: tp.Set[str] = custom_arg_names, _arg_cache_dct: tp.Kwargs = arg_cache_dct, _use_shortcuts: bool = use_shortcuts, _use_caching: bool = use_caching, ) -> tp.Any: if attr in _final_kwargs: return _final_kwargs[attr] if args is None: args = () if kwargs is None: kwargs = {} if obj is custom_reself: resolve_path_arg = _final_kwargs.pop( "resolve_path_" + attr, True, ) if resolve_path_arg: if call_attr: cond_kwargs = { k: v for k, v in _final_kwargs.items() if k in _opt_arg_names } out = custom_reself.resolve_attr( attr, # do not pass _attr, important for caching args=args, cond_kwargs=cond_kwargs, kwargs=kwargs, custom_arg_names=_custom_arg_names, cache_dct=_arg_cache_dct, use_caching=_use_caching, passed_kwargs_out=passed_kwargs_out, use_shortcuts=_use_shortcuts, ) else: if isinstance(obj, AttrResolverMixin): cls_dir = obj.cls_dir else: cls_dir = dir(type(obj)) if "get_" + attr in cls_dir: _attr = "get_" + attr else: _attr = attr out = getattr(obj, _attr) _select_col_arg = _final_kwargs.pop( "select_col_" + attr, False, ) if _select_col_arg and _column is not None: out = custom_reself.select_col_from_obj( out, _column, wrapper=custom_reself.wrapper.regroup(_group_by), ) passed_kwargs_out["group_by"] = _group_by passed_kwargs_out["column"] = _column return out out = getattr(obj, attr) if callable(out) and call_attr: return out(*args, **kwargs) return out plot_func = custom_reself.deep_getattr( plot_func, getattr_func=_getattr_func, call_last_attr=False, ) if "group_by" in passed_kwargs_out: if "pass_group_by" not in final_kwargs: final_kwargs.pop("group_by", None) if "column" in passed_kwargs_out: if "pass_column" not in final_kwargs: final_kwargs.pop("column", None) if not callable(plot_func): raise TypeError("plot_func must be callable") # Resolve arguments func_arg_names = get_func_arg_names(plot_func) for k in func_arg_names: if k not in final_kwargs: resolve_arg 
= final_kwargs.pop("resolve_" + k, False) use_shortcuts_arg = final_kwargs.pop("use_shortcuts_" + k, True) select_col_arg = final_kwargs.pop("select_col_" + k, False) if resolve_arg: try: arg_out = custom_reself.resolve_attr( k, cond_kwargs=final_kwargs, custom_arg_names=custom_arg_names, cache_dct=arg_cache_dct, use_caching=use_caching, use_shortcuts=use_shortcuts_arg, ) except AttributeError: continue if select_col_arg and _column is not None: arg_out = custom_reself.select_col_from_obj( arg_out, _column, wrapper=custom_reself.wrapper.regroup(_group_by), ) final_kwargs[k] = arg_out for k in list(final_kwargs.keys()): if k in opt_arg_names: if "pass_" + k in final_kwargs: if not final_kwargs.get("pass_" + k): # first priority final_kwargs.pop(k, None) elif k not in func_arg_names: # second priority final_kwargs.pop(k, None) for k in list(final_kwargs.keys()): if k.startswith("pass_") or k.startswith("resolve_"): final_kwargs.pop(k, None) # cleanup # Call plot_func plot_func(**final_kwargs) else: # Do not resolve plot_func plot_func(custom_reself, _subplot_settings) # Update global layout for annotation in fig.layout.annotations: if "text" in annotation and annotation["text"] == "$title_" + str(i): annotation.update(text=title) subplot_layout = dict() subplot_layout[xaxis] = merge_dicts(dict(title="Index"), xaxis_kwargs) subplot_layout[yaxis] = merge_dicts(dict(), yaxis_kwargs) fig.update_layout(**subplot_layout) except Exception as e: warn(f"Subplot '{subplot_name}' raised an exception") raise e # Hide legend labels if not show_legend: for i in range(trace_start_idx, len(fig.data)): fig.data[i].update(showlegend=False) # Show column label if show_column_label: if column is not None: _column = column else: _column = reself.wrapper.get_columns(group_by=group_by)[0] for i in range(trace_start_idx, len(fig.data)): trace = fig.data[i] if trace["name"] is not None: trace.update(name=trace["name"] + f" [{ParamLoc.encode_key(_column)}]") # Remove duplicate legend labels 
found_ids = dict() unique_idx = trace_start_idx for i in range(trace_start_idx, len(fig.data)): trace = fig.data[i] if trace["showlegend"] is not False and trace["legendgroup"] is None: if "name" in trace: name = trace["name"] else: name = None if "marker" in trace: marker = trace["marker"] else: marker = {} if "symbol" in marker: marker_symbol = marker["symbol"] else: marker_symbol = None if "color" in marker: marker_color = marker["color"] else: marker_color = None if "line" in trace: line = trace["line"] else: line = {} if "dash" in line: line_dash = line["dash"] else: line_dash = None if "color" in line: line_color = line["color"] else: line_color = None id = (name, marker_symbol, marker_color, line_dash, line_color) if id in found_ids: if hide_id_labels: trace.update(showlegend=False) if group_id_labels: trace.update(legendgroup=found_ids[id]) else: if group_id_labels: trace.update(legendgroup=unique_idx) found_ids[id] = unique_idx unique_idx += 1 # Hide identical legend labels if hide_id_labels: legendgroups = set() for i in range(trace_start_idx, len(fig.data)): trace = fig.data[i] if trace["legendgroup"] is not None: if trace["showlegend"]: if trace["legendgroup"] in legendgroups: trace.update(showlegend=False) else: legendgroups.add(trace["legendgroup"]) # Remove all except the last title if sharing the same axis if shared_xaxes: i = 0 for row in range(rows): for col in range(cols): if specs[row][col] is not None: xaxis = "xaxis" if i == 0 else "xaxis" + str(i + 1) if row < rows - 1: fig.layout[xaxis].update(title=None) i += 1 if shared_yaxes: i = 0 for row in range(rows): for col in range(cols): if specs[row][col] is not None: yaxis = "yaxis" if i == 0 else "yaxis" + str(i + 1) if col > 0: fig.layout[yaxis].update(title=None) i += 1 # Return the figure return fig # ############# Docs ############# # @classmethod def build_subplots_doc(cls, source_cls: tp.Optional[type] = None) -> str: """Build subplots documentation.""" if source_cls is None: source_cls = 
@classmethod
def build_subplots_doc(cls, source_cls: tp.Optional[type] = None) -> str:
    """Build subplots documentation.

    Takes the docstring of the `subplots` attribute of `source_cls`
    (`PlotsBuilderMixin` if None), cleans its indentation, and fills the
    `subplots` and `cls_name` placeholders using `cls`.
    """
    doc_owner = PlotsBuilderMixin if source_cls is None else source_cls
    raw_doc = get_dict_attr(doc_owner, "subplots").__doc__
    doc_template = string.Template(inspect.cleandoc(raw_doc))
    return doc_template.substitute(
        {"subplots": cls.subplots.prettify(), "cls_name": cls.__name__},
    )


@classmethod
def override_subplots_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
    """Call this method on each subclass that overrides `PlotsBuilderMixin.subplots`.

    Stores the documentation built by `build_subplots_doc` in `__pdoc__`
    under the key `<class name>.subplots`.
    """
    doc_key = f"{cls.__name__}.subplots"
    __pdoc__[doc_key] = cls.build_subplots_doc(source_cls=source_cls)
def clean_labels(labels: tp.Labels) -> tp.Labels:
    """Make labels digestible for Plotly, which doesn't support multi-indexes.

    A `pd.MultiIndex` is flattened into an index of tuples, a `pd.PeriodIndex`
    is mapped to strings, and labels whose first element is a tuple are
    converted into a list of strings. Anything else is returned as is.
    """
    cleaned = labels
    if isinstance(cleaned, pd.MultiIndex):
        cleaned = cleaned.to_flat_index()
    if isinstance(cleaned, pd.PeriodIndex):
        cleaned = cleaned.map(str)
    # Only the first element is inspected to decide whether labels are tuples
    starts_with_tuple = len(cleaned) > 0 and isinstance(cleaned[0], tuple)
    if starts_with_tuple:
        return [str(label) for label in cleaned]
    return cleaned
class Gauge(TraceType, TraceUpdater):
    """Gauge plot.

    Args:
        value (float): The value to be displayed.
        label (str): The label to be displayed.
        value_range (tuple of float): The value range of the gauge.
        cmap_name (str): A matplotlib-compatible colormap name.

            See the [list of available colormaps](https://matplotlib.org/tutorials/colors/colormaps.html).
        trace_kwargs (dict): Keyword arguments passed to the `plotly.graph_objects.Indicator`.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Usage:
        ```pycon
        >>> from vectorbtpro import *

        >>> gauge = vbt.Gauge(
        ...     value=2,
        ...     value_range=(1, 3),
        ...     label='My Gauge'
        ... )
        >>> gauge.fig.show()
        ```

        ![](/assets/images/api/Gauge.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/Gauge.dark.svg#only-dark){: .iimg loading=lazy }
    """

    def __init__(
        self,
        value: tp.Optional[float] = None,
        label: tp.Optional[str] = None,
        value_range: tp.Optional[tp.Tuple[float, float]] = None,
        cmap_name: str = "Spectral",
        trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        make_figure_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> None:
        # Hand the arguments over exactly as received, before any defaults are filled in
        TraceType.__init__(
            self,
            value=value,
            label=label,
            value_range=value_range,
            cmap_name=cmap_name,
            trace_kwargs=trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            make_figure_kwargs=make_figure_kwargs,
            fig=fig,
            **layout_kwargs,
        )

        from vectorbtpro._settings import settings

        layout_cfg = settings["plotting"]["layout"]

        trace_kwargs = {} if trace_kwargs is None else trace_kwargs
        add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs

        if fig is None:
            fig = make_figure(**resolve_dict(make_figure_kwargs))
            if "width" in layout_cfg:
                # Calculate nice width and height
                base_width = layout_cfg["width"]
                fig.update_layout(width=base_width * 0.7, height=base_width * 0.5, margin=dict(t=80))
        fig.update_layout(**layout_kwargs)

        indicator_defaults = dict(
            domain=dict(x=[0, 1], y=[0, 1]),
            mode="gauge+number+delta",
            title=dict(text=label),
        )
        indicator = go.Indicator(**merge_dicts(indicator_defaults, trace_kwargs))
        if value is not None:
            self.update_trace(indicator, value, value_range=value_range, cmap_name=cmap_name)
        fig.add_trace(indicator, **add_trace_kwargs)

        # The indicator just added is the last trace of the figure
        TraceUpdater.__init__(self, fig, (fig.data[-1],))
        self._value_range = value_range
        self._cmap_name = cmap_name

    @property
    def value_range(self) -> tp.Tuple[float, float]:
        """The value range of the gauge."""
        return self._value_range

    @property
    def cmap_name(self) -> str:
        """A matplotlib-compatible colormap name."""
        return self._cmap_name

    @classmethod
    def update_trace(
        cls,
        trace: BaseTraceType,
        value: float,
        value_range: tp.Optional[tp.Tuple[float, float]] = None,
        cmap_name: str = "Spectral",
    ) -> None:
        """Update the gauge trace with a new value.

        If `value_range` is given, it becomes the axis range and, unless `cmap_name`
        is None, the bar is colored by mapping `value` onto the colormap within that range.
        The previously displayed value becomes the delta reference.
        """
        if value_range is not None:
            trace.gauge.axis.range = value_range
            if cmap_name is not None:
                lower, upper = value_range[0], value_range[1]
                trace.gauge.bar.color = map_value_to_cmap(value, cmap_name, vmin=lower, vmax=upper)
        trace.delta.reference = trace.value
        trace.value = value

    def update(self, value: float) -> None:
        """Display a new value, widening the stored value range to include it."""
        current_range = self.value_range
        if current_range is None:
            self._value_range = value, value
        else:
            self._value_range = min(current_range[0], value), max(current_range[1], value)

        with self.fig.batch_update():
            self.update_trace(
                self.traces[0],
                value=value,
                value_range=self.value_range,
                cmap_name=self.cmap_name,
            )
class Bar(TraceType, TraceUpdater):
    """Bar plot.

    Args:
        data (array_like): Data in any format that can be converted to NumPy.

            Must be of shape (`x_labels`, `trace_names`).
        trace_names (str or list of str): Trace names, corresponding to columns in pandas.
        x_labels (array_like): X-axis labels, corresponding to index in pandas.
        trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Bar`.

            Can be specified per trace as a sequence of dicts.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Raises:
        ValueError: If neither `data` nor `trace_names` is passed.

    Usage:
        ```pycon
        >>> from vectorbtpro import *

        >>> bar = vbt.Bar(
        ...     data=[[1, 2], [3, 4]],
        ...     trace_names=['a', 'b'],
        ...     x_labels=['x', 'y']
        ... )
        >>> bar.fig.show()
        ```

        ![](/assets/images/api/Bar.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/Bar.dark.svg#only-dark){: .iimg loading=lazy }
    """

    def __init__(
        self,
        data: tp.Optional[tp.ArrayLike] = None,
        trace_names: tp.TraceNames = None,
        x_labels: tp.Optional[tp.Labels] = None,
        trace_kwargs: tp.KwargsLikeSequence = None,
        add_trace_kwargs: tp.KwargsLike = None,
        make_figure_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> None:
        # Hand the arguments over exactly as received, before any defaults are filled in
        TraceType.__init__(
            self,
            data=data,
            trace_names=trace_names,
            x_labels=x_labels,
            trace_kwargs=trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            make_figure_kwargs=make_figure_kwargs,
            fig=fig,
            **layout_kwargs,
        )

        if trace_kwargs is None:
            trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        if data is not None:
            data = reshaping.to_2d_array(data)
            if trace_names is not None:
                checks.assert_shape_equal(data, trace_names, (1, 0))
        else:
            if trace_names is None:
                raise ValueError("At least data or trace_names must be passed")
        if trace_names is None:
            # One unnamed trace per data column
            trace_names = [None] * data.shape[1]
        if isinstance(trace_names, str):
            trace_names = [trace_names]
        if x_labels is not None:
            x_labels = clean_labels(x_labels)

        if fig is None:
            fig = make_figure(**resolve_dict(make_figure_kwargs))
        fig.update_layout(**layout_kwargs)

        for i, trace_name in enumerate(trace_names):
            _trace_kwargs = resolve_dict(trace_kwargs, i=i)
            # A name given in the trace keyword arguments overrides the one from `trace_names`
            trace_name = _trace_kwargs.pop("name", trace_name)
            if trace_name is not None:
                trace_name = str(trace_name)
            _trace_kwargs = merge_dicts(
                dict(x=x_labels, name=trace_name, showlegend=trace_name is not None),
                _trace_kwargs,
            )
            trace = go.Bar(**_trace_kwargs)
            if data is not None:
                self.update_trace(trace, data, i)
            fig.add_trace(trace, **add_trace_kwargs)

        # Track only the traces added above. A negative-length slice (`[-n:]`) would select
        # ALL traces of the figure when n == 0 because `-0 == 0`, so slice from an explicit start.
        n_added = len(trace_names)
        TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - n_added :])

    @classmethod
    def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, i: int) -> None:
        """Set the bar heights of `trace` to column `i` of `data`.

        If the trace marker has a colorscale, the same column is also used as the marker color.
        """
        data = reshaping.to_2d_array(data)
        trace.y = data[:, i]
        if trace.marker.colorscale is not None:
            trace.marker.color = data[:, i]

    def update(self, data: tp.ArrayLike) -> None:
        """Update every tracked trace with its column of `data` in one batch update."""
        data = reshaping.to_2d_array(data)
        with self.fig.batch_update():
            for i, trace in enumerate(self.traces):
                self.update_trace(trace, data, i)
def __init__(
    self,
    data: tp.Optional[tp.ArrayLike] = None,
    trace_names: tp.TraceNames = None,
    x_labels: tp.Optional[tp.Labels] = None,
    trace_kwargs: tp.KwargsLikeSequence = None,
    add_trace_kwargs: tp.KwargsLike = None,
    make_figure_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    use_gl: tp.Optional[bool] = None,
    **layout_kwargs,
) -> None:
    # Hand the arguments to the base initializer exactly as they were received.
    TraceType.__init__(
        self,
        data=data,
        trace_names=trace_names,
        x_labels=x_labels,
        trace_kwargs=trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        make_figure_kwargs=make_figure_kwargs,
        fig=fig,
        **layout_kwargs,
    )

    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]

    if trace_kwargs is None:
        trace_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}
    # A single name is shorthand for a one-element list. Wrap it before
    # validation so that the shape check compares a sequence of names
    # (not a bare string) against the number of data columns.
    if isinstance(trace_names, str):
        trace_names = [trace_names]
    if data is not None:
        data = reshaping.to_2d_array(data)
        if trace_names is not None:
            checks.assert_shape_equal(data, trace_names, (1, 0))
    elif trace_names is None:
        raise ValueError("At least data or trace_names must be passed")
    if trace_names is None:
        # One unnamed trace per data column
        trace_names = [None] * data.shape[1]
    if x_labels is not None:
        x_labels = clean_labels(x_labels)

    if fig is None:
        fig = make_figure(**resolve_dict(make_figure_kwargs))
    fig.update_layout(**layout_kwargs)

    # Whether the figure is a plotly-resampler figure does not depend on the
    # trace, so detect it once instead of on every loop iteration.
    # plotly-resampler is optional: without it the regular path is used.
    try:
        from plotly_resampler.aggregation import AbstractFigureAggregator

        use_resampler = isinstance(fig, AbstractFigureAggregator)
    except ImportError:
        use_resampler = False

    for i, trace_name in enumerate(trace_names):
        _trace_kwargs = resolve_dict(trace_kwargs, i=i)
        # Resolution order: per-trace keyword, then argument, then global setting,
        # then a size-based default (at least 10,000 data points)
        _use_gl = _trace_kwargs.pop("use_gl", use_gl)
        if _use_gl is None:
            _use_gl = plotting_cfg["use_gl"]
        if _use_gl is None:
            _use_gl = data is not None and data.size >= 10000
        # A per-trace "name" keyword overrides the positional trace name
        trace_name = _trace_kwargs.pop("name", trace_name)
        if trace_name is not None:
            trace_name = str(trace_name)
        scatter_obj = go.Scattergl if _use_gl else go.Scatter
        if use_resampler:
            if data is None:
                raise ValueError("Cannot create empty scatter traces when using plotly-resampler")
            # The data is handed to the figure via `hf_x`/`hf_y` rather than set on the trace
            _trace_kwargs = merge_dicts(
                dict(name=trace_name, showlegend=trace_name is not None),
                _trace_kwargs,
            )
            trace = scatter_obj(**_trace_kwargs)
            fig.add_trace(trace, hf_x=x_labels, hf_y=data[:, i], **add_trace_kwargs)
        else:
            _trace_kwargs = merge_dicts(
                dict(x=x_labels, name=trace_name, showlegend=trace_name is not None),
                _trace_kwargs,
            )
            trace = scatter_obj(**_trace_kwargs)
            if data is not None:
                self.update_trace(trace, data, i)
            fig.add_trace(trace, **add_trace_kwargs)

    # Track only the traces added above. Slicing from an explicit start index
    # is required because `fig.data[-0:]` would return every trace of the figure
    # when no trace was added.
    TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - len(trace_names) :])

@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, i: int) -> None:
    """Set the Y-values of `trace` from column `i` of `data`."""
    data = reshaping.to_2d_array(data)
    trace.y = data[:, i]

def update(self, data: tp.ArrayLike) -> None:
    """Update all traces of this instance from `data` in a single batch update."""
    data = reshaping.to_2d_array(data)
    with self.fig.batch_update():
        for i, trace in enumerate(self.traces):
            self.update_trace(trace, data, i)
def __init__(
    self,
    data: tp.Optional[tp.ArrayLike] = None,
    trace_names: tp.TraceNames = None,
    horizontal: bool = False,
    remove_nan: bool = True,
    from_quantile: tp.Optional[float] = None,
    to_quantile: tp.Optional[float] = None,
    trace_kwargs: tp.KwargsLikeSequence = None,
    add_trace_kwargs: tp.KwargsLike = None,
    make_figure_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> None:
    # Hand the arguments to the base initializer exactly as they were received.
    TraceType.__init__(
        self,
        data=data,
        trace_names=trace_names,
        horizontal=horizontal,
        remove_nan=remove_nan,
        from_quantile=from_quantile,
        to_quantile=to_quantile,
        trace_kwargs=trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        make_figure_kwargs=make_figure_kwargs,
        fig=fig,
        **layout_kwargs,
    )

    if trace_kwargs is None:
        trace_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}
    # A single name is shorthand for a one-element list. Wrap it before
    # validation so that the shape check compares a sequence of names
    # (not a bare string) against the number of data columns.
    if isinstance(trace_names, str):
        trace_names = [trace_names]
    if data is not None:
        data = reshaping.to_2d_array(data)
        if trace_names is not None:
            checks.assert_shape_equal(data, trace_names, (1, 0))
    elif trace_names is None:
        raise ValueError("At least data or trace_names must be passed")
    if trace_names is None:
        # One unnamed trace per data column
        trace_names = [None] * data.shape[1]

    if fig is None:
        fig = make_figure(**resolve_dict(make_figure_kwargs))
        # Overlay mode is only applied to figures created here
        fig.update_layout(barmode="overlay")
    fig.update_layout(**layout_kwargs)

    for i, trace_name in enumerate(trace_names):
        _trace_kwargs = resolve_dict(trace_kwargs, i=i)
        # A per-trace "name" keyword overrides the positional trace name
        trace_name = _trace_kwargs.pop("name", trace_name)
        if trace_name is not None:
            trace_name = str(trace_name)
        _trace_kwargs = merge_dicts(
            dict(
                # Overlaid histograms are made translucent to remain readable
                opacity=0.75 if len(trace_names) > 1 else 1,
                name=trace_name,
                showlegend=trace_name is not None,
            ),
            _trace_kwargs,
        )
        trace = go.Histogram(**_trace_kwargs)
        if data is not None:
            self.update_trace(
                trace,
                data,
                i,
                horizontal=horizontal,
                remove_nan=remove_nan,
                from_quantile=from_quantile,
                to_quantile=to_quantile,
            )
        fig.add_trace(trace, **add_trace_kwargs)

    # Track only the traces added above. Slicing from an explicit start index
    # is required because `fig.data[-0:]` would return every trace of the figure
    # when no trace was added.
    TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - len(trace_names) :])
    self._horizontal = horizontal
    self._remove_nan = remove_nan
    self._from_quantile = from_quantile
    self._to_quantile = to_quantile

@property
def horizontal(self) -> bool:
    """Whether to plot horizontally."""
    return self._horizontal

@property
def remove_nan(self) -> bool:
    """Whether to remove NaN values."""
    return self._remove_nan

@property
def from_quantile(self) -> tp.Optional[float]:
    """Filter out data points before this quantile."""
    return self._from_quantile

@property
def to_quantile(self) -> tp.Optional[float]:
    """Filter out data points after this quantile."""
    return self._to_quantile

@classmethod
def update_trace(
    cls,
    trace: BaseTraceType,
    data: tp.ArrayLike,
    i: int,
    horizontal: bool = False,
    remove_nan: bool = True,
    from_quantile: tp.Optional[float] = None,
    to_quantile: tp.Optional[float] = None,
) -> None:
    """Set the values of `trace` from column `i` of `data`.

    NaN values are dropped if `remove_nan` is True. If `from_quantile` and/or `to_quantile`
    is given, only the values between those quantiles (inclusive) are kept. The values are
    assigned to the Y-axis if `horizontal` is True, otherwise to the X-axis; the other axis is reset."""
    data = reshaping.to_2d_array(data)
    d = data[:, i]
    if remove_nan:
        d = d[~np.isnan(d)]
    # Both quantiles are computed on the same array, that is, before any quantile filtering
    mask = np.full(d.shape, True)
    if from_quantile is not None:
        mask &= d >= np.quantile(d, from_quantile)
    if to_quantile is not None:
        mask &= d <= np.quantile(d, to_quantile)
    d = d[mask]
    if horizontal:
        trace.x = None
        trace.y = d
    else:
        trace.x = d
        trace.y = None

def update(self, data: tp.ArrayLike) -> None:
    """Update all traces of this instance from `data` in a single batch update.

    Uses the orientation, NaN and quantile settings this instance was created with."""
    data = reshaping.to_2d_array(data)
    with self.fig.batch_update():
        for i, trace in enumerate(self.traces):
            self.update_trace(
                trace,
                data,
                i,
                horizontal=self.horizontal,
                remove_nan=self.remove_nan,
                from_quantile=self.from_quantile,
                to_quantile=self.to_quantile,
            )
def __init__(
    self,
    data: tp.Optional[tp.ArrayLike] = None,
    trace_names: tp.TraceNames = None,
    horizontal: bool = False,
    remove_nan: bool = True,
    from_quantile: tp.Optional[float] = None,
    to_quantile: tp.Optional[float] = None,
    trace_kwargs: tp.KwargsLikeSequence = None,
    add_trace_kwargs: tp.KwargsLike = None,
    make_figure_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> None:
    # Hand the arguments to the base initializer exactly as they were received.
    TraceType.__init__(
        self,
        data=data,
        trace_names=trace_names,
        horizontal=horizontal,
        remove_nan=remove_nan,
        from_quantile=from_quantile,
        to_quantile=to_quantile,
        trace_kwargs=trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        make_figure_kwargs=make_figure_kwargs,
        fig=fig,
        **layout_kwargs,
    )

    if trace_kwargs is None:
        trace_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}
    # A single name is shorthand for a one-element list. Wrap it before
    # validation so that the shape check compares a sequence of names
    # (not a bare string) against the number of data columns.
    if isinstance(trace_names, str):
        trace_names = [trace_names]
    if data is not None:
        data = reshaping.to_2d_array(data)
        if trace_names is not None:
            checks.assert_shape_equal(data, trace_names, (1, 0))
    elif trace_names is None:
        raise ValueError("At least data or trace_names must be passed")
    if trace_names is None:
        # One unnamed trace per data column
        trace_names = [None] * data.shape[1]

    if fig is None:
        fig = make_figure(**resolve_dict(make_figure_kwargs))
    fig.update_layout(**layout_kwargs)

    for i, trace_name in enumerate(trace_names):
        _trace_kwargs = resolve_dict(trace_kwargs, i=i)
        # A per-trace "name" keyword overrides the positional trace name
        trace_name = _trace_kwargs.pop("name", trace_name)
        if trace_name is not None:
            trace_name = str(trace_name)
        _trace_kwargs = merge_dicts(
            dict(name=trace_name, showlegend=trace_name is not None, boxmean="sd"),
            _trace_kwargs,
        )
        trace = go.Box(**_trace_kwargs)
        if data is not None:
            self.update_trace(
                trace,
                data,
                i,
                horizontal=horizontal,
                remove_nan=remove_nan,
                from_quantile=from_quantile,
                to_quantile=to_quantile,
            )
        fig.add_trace(trace, **add_trace_kwargs)

    # Track only the traces added above. Slicing from an explicit start index
    # is required because `fig.data[-0:]` would return every trace of the figure
    # when no trace was added.
    TraceUpdater.__init__(self, fig, fig.data[len(fig.data) - len(trace_names) :])
    self._horizontal = horizontal
    self._remove_nan = remove_nan
    self._from_quantile = from_quantile
    self._to_quantile = to_quantile

@property
def horizontal(self) -> bool:
    """Whether to plot horizontally."""
    return self._horizontal

@property
def remove_nan(self) -> bool:
    """Whether to remove NaN values."""
    return self._remove_nan

@property
def from_quantile(self) -> tp.Optional[float]:
    """Filter out data points before this quantile."""
    return self._from_quantile

@property
def to_quantile(self) -> tp.Optional[float]:
    """Filter out data points after this quantile."""
    return self._to_quantile

@classmethod
def update_trace(
    cls,
    trace: BaseTraceType,
    data: tp.ArrayLike,
    i: int,
    horizontal: bool = False,
    remove_nan: bool = True,
    from_quantile: tp.Optional[float] = None,
    to_quantile: tp.Optional[float] = None,
) -> None:
    """Set the values of `trace` from column `i` of `data`.

    NaN values are dropped if `remove_nan` is True. If `from_quantile` and/or `to_quantile`
    is given, only the values between those quantiles (inclusive) are kept. The values are
    assigned to the X-axis if `horizontal` is True, otherwise to the Y-axis; the other axis is reset."""
    data = reshaping.to_2d_array(data)
    d = data[:, i]
    if remove_nan:
        d = d[~np.isnan(d)]
    # Both quantiles are computed on the same array, that is, before any quantile filtering
    mask = np.full(d.shape, True)
    if from_quantile is not None:
        mask &= d >= np.quantile(d, from_quantile)
    if to_quantile is not None:
        mask &= d <= np.quantile(d, to_quantile)
    d = d[mask]
    if horizontal:
        trace.x = d
        trace.y = None
    else:
        trace.x = None
        trace.y = d

def update(self, data: tp.ArrayLike) -> None:
    """Update all traces of this instance from `data` in a single batch update.

    Uses the orientation, NaN and quantile settings this instance was created with."""
    data = reshaping.to_2d_array(data)
    with self.fig.batch_update():
        for i, trace in enumerate(self.traces):
            self.update_trace(
                trace,
                data,
                i,
                horizontal=self.horizontal,
                remove_nan=self.remove_nan,
                from_quantile=self.from_quantile,
                to_quantile=self.to_quantile,
            )
def __init__(
    self,
    data: tp.Optional[tp.ArrayLike] = None,
    x_labels: tp.Optional[tp.Labels] = None,
    y_labels: tp.Optional[tp.Labels] = None,
    is_x_category: bool = False,
    is_y_category: bool = False,
    trace_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    make_figure_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> None:
    # Forward every argument to the base initializer, including the category
    # flags, in the same way the other trace types forward all of theirs.
    TraceType.__init__(
        self,
        data=data,
        x_labels=x_labels,
        y_labels=y_labels,
        is_x_category=is_x_category,
        is_y_category=is_y_category,
        trace_kwargs=trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        make_figure_kwargs=make_figure_kwargs,
        fig=fig,
        **layout_kwargs,
    )

    from vectorbtpro._settings import settings

    layout_cfg = settings["plotting"]["layout"]

    if trace_kwargs is None:
        trace_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}
    if data is not None:
        data = reshaping.to_2d_array(data)
        # Columns of data correspond to X-labels, rows to Y-labels
        if x_labels is not None:
            checks.assert_shape_equal(data, x_labels, (1, 0))
        if y_labels is not None:
            checks.assert_shape_equal(data, y_labels, (0, 0))
    elif x_labels is None or y_labels is None:
        raise ValueError("At least data, or x_labels and y_labels must be passed")
    if x_labels is not None:
        x_labels = clean_labels(x_labels)
    if y_labels is not None:
        y_labels = clean_labels(y_labels)

    if fig is None:
        fig = make_figure(**resolve_dict(make_figure_kwargs))
        if "width" in layout_cfg:
            # Calculate nice width and height
            max_width = layout_cfg["width"]
            if data is not None:
                x_len = data.shape[1]
                y_len = data.shape[0]
            else:
                x_len = len(x_labels)
                y_len = len(y_labels)
            total_len = x_len + y_len
            # Without any cell there is no aspect ratio to derive (and the ratio
            # would divide by zero), so the default figure size is kept
            if total_len > 0:
                width = math.ceil(rescale(x_len / total_len, (0, 1), (0.3 * max_width, max_width)))
                width = min(width + 150, max_width)  # account for colorbar
                height = math.ceil(rescale(y_len / total_len, (0, 1), (0.3 * max_width, max_width)))
                height = min(height, max_width * 0.7)  # limit height
                fig.update_layout(width=width, height=height)

    _trace_kwargs = merge_dicts(
        dict(hoverongaps=False, colorscale="Plasma", x=x_labels, y=y_labels),
        trace_kwargs,
    )
    trace = go.Heatmap(**_trace_kwargs)
    if data is not None:
        self.update_trace(trace, data)
    fig.add_trace(trace, **add_trace_kwargs)

    # Translate the axis references of the added trace (such as "x2") into the
    # corresponding layout keys (such as "xaxis2")
    added_trace = fig.data[-1]
    xref = added_trace["xaxis"] if added_trace["xaxis"] is not None else "x"
    yref = added_trace["yaxis"] if added_trace["yaxis"] is not None else "y"
    xaxis = "xaxis" + xref[1:]
    yaxis = "yaxis" + yref[1:]
    axis_kwargs = dict()
    if is_x_category:
        axis_kwargs[xaxis] = dict(type="category")
    if is_y_category:
        axis_kwargs[yaxis] = dict(type="category")
    fig.update_layout(**axis_kwargs)
    # User-provided layout goes last so that it can override the axis types
    fig.update_layout(**layout_kwargs)

    TraceUpdater.__init__(self, fig, (fig.data[-1],))

@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, *args, **kwargs) -> None:
    """Set the Z-values of `trace` to `data` converted to a two-dimensional array.

    Additional arguments are accepted but not used."""
    trace.z = reshaping.to_2d_array(data)

def update(self, data: tp.ArrayLike) -> None:
    """Update the (only) trace of this instance from `data` in a batch update."""
    with self.fig.batch_update():
        self.update_trace(self.traces[0], data)
def __init__(
    self,
    data: tp.Optional[tp.ArrayLike] = None,
    x_labels: tp.Optional[tp.Labels] = None,
    y_labels: tp.Optional[tp.Labels] = None,
    z_labels: tp.Optional[tp.Labels] = None,
    trace_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    scene_name: str = "scene",
    make_figure_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> None:
    # Hand the arguments to the base initializer exactly as they were received.
    TraceType.__init__(
        self,
        data=data,
        x_labels=x_labels,
        y_labels=y_labels,
        z_labels=z_labels,
        trace_kwargs=trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        scene_name=scene_name,
        make_figure_kwargs=make_figure_kwargs,
        fig=fig,
        **layout_kwargs,
    )

    from vectorbtpro._settings import settings

    layout_cfg = settings["plotting"]["layout"]

    trace_kwargs = {} if trace_kwargs is None else trace_kwargs
    add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs

    # The three axes are handled uniformly; position 0, 1, 2 stands for X, Y, Z
    axis_labels = [x_labels, y_labels, z_labels]
    if data is None:
        if any(labels is None for labels in axis_labels):
            raise ValueError("At least data, or x_labels, y_labels and z_labels must be passed")
        axis_lens = [len(labels) for labels in axis_labels]
    else:
        checks.assert_ndim(data, 3)
        data = np.asarray(data)
        axis_lens = list(data.shape)
        # Labels, where provided, must match the corresponding data dimension
        for axis, labels in enumerate(axis_labels):
            if labels is not None:
                checks.assert_shape_equal(data, labels, (axis, 0))

    # Missing labels default to positions
    for axis, labels in enumerate(axis_labels):
        axis_labels[axis] = np.arange(axis_lens[axis]) if labels is None else clean_labels(labels)
    axis_labels = [np.asarray(labels) for labels in axis_labels]

    if fig is None:
        fig = make_figure(**resolve_dict(make_figure_kwargs))
        if "width" in layout_cfg:
            # Calculate nice width and height
            base_width = layout_cfg["width"]
            fig.update_layout(width=base_width, height=0.7 * base_width)

    # Non-numeric data types are not supported by go.Volume, so use ticktext
    # Note: Currently plotly displays the entire tick array, in future versions it will be more sensible
    scene_layout = dict()
    for axis, axis_key in enumerate(("xaxis", "yaxis", "zaxis")):
        if np.issubdtype(axis_labels[axis].dtype, np.number):
            continue
        ticktext = axis_labels[axis]
        # Plot against positions and display the original labels as tick text
        axis_labels[axis] = np.arange(axis_lens[axis])
        scene_layout[axis_key] = dict(ticktext=ticktext, tickvals=axis_labels[axis], tickmode="array")
    fig.update_layout(**{scene_name: scene_layout})
    fig.update_layout(**layout_kwargs)

    # Arrays must have the same length as the flattened data array
    x_arr, y_arr, z_arr = axis_labels
    n_x, n_y, n_z = len(x_arr), len(y_arr), len(z_arr)
    x = np.repeat(x_arr, n_y * n_z)
    y = np.tile(np.repeat(y_arr, n_z), n_x)
    z = np.tile(z_arr, n_x * n_y)

    _trace_kwargs = merge_dicts(
        dict(x=x, y=y, z=z, opacity=0.2, surface_count=15, colorscale="Plasma"),
        trace_kwargs,
    )
    trace = go.Volume(**_trace_kwargs)
    if data is not None:
        self.update_trace(trace, data)
    fig.add_trace(trace, **add_trace_kwargs)

    TraceUpdater.__init__(self, fig, (fig.data[-1],))

@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, *args, **kwargs) -> None:
    """Set the values of `trace` to the flattened `data`.

    Additional arguments are accepted but not used."""
    flat_values = np.asarray(data).flatten()
    trace.value = flat_values

def update(self, data: tp.ArrayLike) -> None:
    """Update the (only) trace of this instance from `data` in a batch update."""
    only_trace = self.traces[0]
    with self.fig.batch_update():
        self.update_trace(only_trace, data)
@attach_shortcut_properties(price_records_shortcut_config)
class PriceRecords(Records):
    """Extends `vectorbtpro.records.base.Records` for records that can make use of OHLC data.

    Each of the open, high, low, and close prices is optional. Prices are stored as
    two-dimensional arrays and are wrapped with the wrapper of this instance on access."""

    @classmethod
    def from_records(
        cls: tp.Type[PriceRecordsT],
        wrapper: ArrayWrapper,
        records: tp.RecordArray,
        data: tp.Optional["Data"] = None,
        open: tp.Optional[tp.ArrayLike] = None,
        high: tp.Optional[tp.ArrayLike] = None,
        low: tp.Optional[tp.ArrayLike] = None,
        close: tp.Optional[tp.ArrayLike] = None,
        attach_data: bool = True,
        **kwargs,
    ) -> PriceRecordsT:
        """Build `PriceRecords` from records.

        A price that is not provided explicitly is taken from the attribute of the
        same name of `data`, if `data` is provided. If `attach_data` is False,
        no price is attached to the new instance."""
        if open is None and data is not None:
            open = data.open
        if high is None and data is not None:
            high = data.high
        if low is None and data is not None:
            low = data.low
        if close is None and data is not None:
            close = data.close
        return cls(
            wrapper,
            records,
            open=open if attach_data else None,
            high=high if attach_data else None,
            low=low if attach_data else None,
            close=close if attach_data else None,
            **kwargs,
        )

    @classmethod
    def resolve_row_stack_kwargs(
        cls: tp.Type[PriceRecordsT],
        *objs: tp.MaybeTuple[PriceRecordsT],
        **kwargs,
    ) -> tp.Kwargs:
        """Resolve keyword arguments for initializing `PriceRecords` after stacking along rows.

        A price that is not already in `kwargs` is stacked only if every object has it."""
        kwargs = Records.resolve_row_stack_kwargs(*objs, **kwargs)
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)
        for obj in objs:
            if not checks.is_instance_of(obj, PriceRecords):
                raise TypeError("Each object to be merged must be an instance of PriceRecords")
        for price_name in ("open", "high", "low", "close"):
            if price_name in kwargs:
                continue
            price_objs = []
            for obj in objs:
                # Test the backing array, as `resolve_column_stack_kwargs` does,
                # so that a price is wrapped only once and only when needed
                if getattr(obj, "_" + price_name) is None:
                    break
                price_objs.append(getattr(obj, price_name))
            else:
                # No object lacks this price
                kwargs[price_name] = kwargs["wrapper"].row_stack_arrs(
                    *price_objs,
                    group_by=False,
                    wrap=False,
                )
        return kwargs

    @classmethod
    def resolve_column_stack_kwargs(
        cls: tp.Type[PriceRecordsT],
        *objs: tp.MaybeTuple[PriceRecordsT],
        reindex_kwargs: tp.KwargsLike = None,
        ffill_close: bool = False,
        fbfill_close: bool = False,
        **kwargs,
    ) -> tp.Kwargs:
        """Resolve keyword arguments for initializing `PriceRecords` after stacking along columns.

        A price that is not already in `kwargs` is stacked only if every object has it.
        The stacked close price is forward- and backward-filled if `fbfill_close` is True,
        otherwise forward-filled if `ffill_close` is True."""
        kwargs = Records.resolve_column_stack_kwargs(*objs, reindex_kwargs=reindex_kwargs, **kwargs)
        # Only needed for stacking, must not reach the initializer
        kwargs.pop("reindex_kwargs", None)
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)
        for obj in objs:
            if not checks.is_instance_of(obj, PriceRecords):
                raise TypeError("Each object to be merged must be an instance of PriceRecords")
        for price_name in ("open", "high", "low", "close"):
            if price_name in kwargs:
                continue
            price_objs = []
            for obj in objs:
                if getattr(obj, "_" + price_name) is None:
                    break
                price_objs.append(getattr(obj, price_name))
            else:
                # No object lacks this price
                new_price = kwargs["wrapper"].column_stack_arrs(
                    *price_objs,
                    reindex_kwargs=reindex_kwargs,
                    group_by=False,
                    wrap=True,
                )
                if price_name == "close":
                    if fbfill_close:
                        new_price = new_price.vbt.fbfill()
                    elif ffill_close:
                        new_price = new_price.vbt.ffill()
                kwargs[price_name] = new_price.values
        return kwargs

    def __init__(
        self,
        wrapper: ArrayWrapper,
        records_arr: tp.RecordArray,
        open: tp.Optional[tp.ArrayLike] = None,
        high: tp.Optional[tp.ArrayLike] = None,
        low: tp.Optional[tp.ArrayLike] = None,
        close: tp.Optional[tp.ArrayLike] = None,
        **kwargs,
    ) -> None:
        # The base initializer receives the prices as they were passed
        Records.__init__(
            self,
            wrapper,
            records_arr,
            open=open,
            high=high,
            low=low,
            close=close,
            **kwargs,
        )

        # Provided prices are kept as two-dimensional arrays
        if open is not None:
            open = to_2d_array(open)
        if high is not None:
            high = to_2d_array(high)
        if low is not None:
            low = to_2d_array(low)
        if close is not None:
            close = to_2d_array(close)

        self._open = open
        self._high = high
        self._low = low
        self._close = close

    def indexing_func_meta(self, *args, records_meta: tp.DictLike = None, **kwargs) -> dict:
        """Perform indexing on `PriceRecords` and return metadata.

        Returns the metadata of `Records.indexing_func_meta` extended by the keys
        "open", "high", "low", and "close", each holding the selected price array or None."""
        if records_meta is None:
            records_meta = Records.indexing_func_meta(self, *args, **kwargs)
        wrapper_meta = records_meta["wrapper_meta"]
        prices = {}
        for price_name in ("open", "high", "low", "close"):
            price = getattr(self, "_" + price_name)
            if price is not None:
                new_price = ArrayWrapper.select_from_flex_array(
                    price,
                    row_idxs=wrapper_meta["row_idxs"],
                    col_idxs=wrapper_meta["col_idxs"],
                    rows_changed=wrapper_meta["rows_changed"],
                    columns_changed=wrapper_meta["columns_changed"],
                )
            else:
                new_price = None
            prices[price_name] = new_price
        return {**records_meta, **prices}

    def indexing_func(self: PriceRecordsT, *args, price_records_meta: tp.DictLike = None, **kwargs) -> PriceRecordsT:
        """Perform indexing on `PriceRecords`.

        Uses `price_records_meta` if provided, otherwise the result of `PriceRecords.indexing_func_meta`."""
        if price_records_meta is None:
            price_records_meta = self.indexing_func_meta(*args, **kwargs)
        return self.replace(
            wrapper=price_records_meta["wrapper_meta"]["new_wrapper"],
            records_arr=price_records_meta["new_records_arr"],
            open=price_records_meta["open"],
            high=price_records_meta["high"],
            low=price_records_meta["low"],
            close=price_records_meta["close"],
        )

    def resample(
        self: PriceRecordsT,
        *args,
        ffill_close: bool = False,
        fbfill_close: bool = False,
        records_meta: tp.DictLike = None,
        **kwargs,
    ) -> PriceRecordsT:
        """Perform resampling on `PriceRecords`.

        Open is reduced with `first_reduce_nb`, high with `max_reduce_nb`, low with `min_reduce_nb`,
        and close with `last_reduce_nb`. A price that is not attached stays None. The resampled close
        price is forward- and backward-filled if `fbfill_close` is True, otherwise forward-filled
        if `ffill_close` is True."""
        if records_meta is None:
            records_meta = self.resample_meta(*args, **kwargs)
        resampler = records_meta["wrapper_meta"]["resampler"]
        if self._open is None:
            new_open = None
        else:
            new_open = self.open.vbt.resample_apply(resampler, nb.first_reduce_nb)
        if self._high is None:
            new_high = None
        else:
            new_high = self.high.vbt.resample_apply(resampler, nb.max_reduce_nb)
        if self._low is None:
            new_low = None
        else:
            new_low = self.low.vbt.resample_apply(resampler, nb.min_reduce_nb)
        if self._close is None:
            new_close = None
        else:
            new_close = self.close.vbt.resample_apply(resampler, nb.last_reduce_nb)
            # Filling applies only when there is a close price to fill
            if fbfill_close:
                new_close = new_close.vbt.fbfill()
            elif ffill_close:
                new_close = new_close.vbt.ffill()
        return self.replace(
            wrapper=records_meta["wrapper_meta"]["new_wrapper"],
            records_arr=records_meta["new_records_arr"],
            open=new_open,
            high=new_high,
            low=new_low,
            close=new_close,
        )

    @property
    def open(self) -> tp.Optional[tp.SeriesFrame]:
        """Open price."""
        if self._open is None:
            return None
        return self.wrapper.wrap(self._open, group_by=False)

    @property
    def high(self) -> tp.Optional[tp.SeriesFrame]:
        """High price."""
        if self._high is None:
            return None
        return self.wrapper.wrap(self._high, group_by=False)

    @property
    def low(self) -> tp.Optional[tp.SeriesFrame]:
        """Low price."""
        if self._low is None:
            return None
        return self.wrapper.wrap(self._low, group_by=False)

    @property
    def close(self) -> tp.Optional[tp.SeriesFrame]:
        """Close price."""
        if self._close is None:
            return None
        return self.wrapper.wrap(self._close, group_by=False)

    def get_bar_open_time(self, **kwargs) -> MappedArray:
        """Get a mapped array with the opening time of the bar."""
        return self.map_array(self.wrapper.index[self.idx_arr], **kwargs)

    def get_bar_close_time(self, **kwargs) -> MappedArray:
        """Get a mapped array with the closing time of the bar.

        Raises a `ValueError` if the wrapper has no frequency."""
        if self.wrapper.freq is None:
            raise ValueError("Must provide frequency")
        return self.map_array(
            Resampler.get_rbound_index(index=self.wrapper.index[self.idx_arr], freq=self.wrapper.freq),
            **kwargs,
        )

    def get_bar_open(self, **kwargs) -> MappedArray:
        """Get a mapped array with the opening price of the bar."""
        return self.apply(nb.bar_price_nb, self._open, **kwargs)

    def get_bar_high(self, **kwargs) -> MappedArray:
        """Get a mapped array with the high price of the bar."""
        return self.apply(nb.bar_price_nb, self._high, **kwargs)

    def get_bar_low(self, **kwargs) -> MappedArray:
        """Get a mapped array with the low price of the bar."""
        return self.apply(nb.bar_price_nb, self._low, **kwargs)

    def get_bar_close(self, **kwargs) -> MappedArray:
        """Get a mapped array with the closing price of the bar."""
        return self.apply(nb.bar_price_nb, self._close, **kwargs)
```pycon >>> from vectorbtpro import * >>> start = '2019-01-01 UTC' # crypto is in UTC >>> end = '2020-01-01 UTC' >>> price = vbt.YFData.pull('BTC-USD', start=start, end=end).get('Close') ``` [=100% "100%"]{: .candystripe .candystripe-animate } ```pycon >>> fast_ma = vbt.MA.run(price, 10) >>> slow_ma = vbt.MA.run(price, 50) >>> fast_below_slow = fast_ma.ma_above(slow_ma) >>> ranges = vbt.Ranges.from_array(fast_below_slow, wrapper_kwargs=dict(freq='d')) >>> ranges.readable Range Id Column Start Timestamp End Timestamp \\ 0 0 0 2019-02-19 00:00:00+00:00 2019-07-25 00:00:00+00:00 1 1 0 2019-08-08 00:00:00+00:00 2019-08-19 00:00:00+00:00 2 2 0 2019-11-01 00:00:00+00:00 2019-11-20 00:00:00+00:00 Status 0 Closed 1 Closed 2 Closed >>> ranges.duration.max(wrap_kwargs=dict(to_timedelta=True)) Timedelta('156 days 00:00:00') ``` ## From accessors Moreover, all generic accessors have a property `ranges` and a method `get_ranges`: ```pycon >>> # vectorbtpro.generic.accessors.GenericAccessor.ranges.coverage >>> fast_below_slow.vbt.ranges.coverage 0.5081967213114754 ``` ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Ranges.metrics`. ```pycon >>> df = pd.DataFrame({ ... 'a': [1, 2, np.nan, np.nan, 5, 6], ... 'b': [np.nan, 2, np.nan, 4, np.nan, 6] ... }) >>> ranges = df.vbt(freq='d').ranges >>> ranges['a'].stats() Start 0 End 5 Period 6 days 00:00:00 Total Records 2 Coverage 0.666667 Overlap Coverage 0.0 Duration: Min 2 days 00:00:00 Duration: Median 2 days 00:00:00 Duration: Max 2 days 00:00:00 Name: a, dtype: object ``` `Ranges.stats` also supports (re-)grouping: ```pycon >>> ranges.stats(group_by=True) Start 0 End 5 Period 6 days 00:00:00 Total Records 5 Coverage 0.416667 Overlap Coverage 0.4 Duration: Min 1 days 00:00:00 Duration: Median 1 days 00:00:00 Duration: Max 2 days 00:00:00 Name: group, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Ranges.subplots`. 
`Ranges` class has a single subplot based on `Ranges.plot`: ```pycon >>> ranges['a'].plots().show() ``` ![](/assets/images/api/ranges_plots.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/ranges_plots.dark.svg#only-dark){: .iimg loading=lazy } """ import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.base.indexes import stack_indexes, combine_indexes, tile_index from vectorbtpro.base.reshaping import to_pd_array, to_1d_array, to_2d_array, tile from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic import nb, enums from vectorbtpro.generic.price_records import PriceRecords from vectorbtpro.records.base import Records from vectorbtpro.records.decorators import override_field_config, attach_fields, attach_shortcut_properties from vectorbtpro.records.mapped_array import MappedArray from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING from vectorbtpro.utils.colors import adjust_lightness from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.execution import Task, execute from vectorbtpro.utils.params import combine_params, Param from vectorbtpro.utils.parsing import get_func_kwargs from vectorbtpro.utils.random_ import set_seed from vectorbtpro.utils.template import substitute_templates from vectorbtpro.utils.warnings_ import warn __all__ = [ "Ranges", "PatternRanges", "PSC", ] __pdoc__ = {} # ############# Ranges ############# # ranges_field_config = ReadonlyConfig( dict( dtype=enums.range_dt, settings=dict( id=dict(title="Range Id"), idx=dict(name="end_idx"), # remap field of Records start_idx=dict(title="Start Index", mapping="index"), end_idx=dict(title="End Index", mapping="index"), 
status=dict(title="Status", mapping=enums.RangeStatus), ), ) ) """_""" __pdoc__[ "ranges_field_config" ] = f"""Field config for `Ranges`. ```python {ranges_field_config.prettify()} ``` """ ranges_attach_field_config = ReadonlyConfig(dict(status=dict(attach_filters=True))) """_""" __pdoc__[ "ranges_attach_field_config" ] = f"""Config of fields to be attached to `Ranges`. ```python {ranges_attach_field_config.prettify()} ``` """ ranges_shortcut_config = ReadonlyConfig( dict( valid=dict(), invalid=dict(), first_pd_mask=dict(obj_type="array"), last_pd_mask=dict(obj_type="array"), ranges_pd_mask=dict(obj_type="array"), first_idx=dict(obj_type="mapped_array"), last_idx=dict(obj_type="mapped_array"), duration=dict(obj_type="mapped_array"), real_duration=dict(obj_type="mapped_array"), avg_duration=dict(obj_type="red_array"), max_duration=dict(obj_type="red_array"), coverage=dict(obj_type="red_array"), overlap_coverage=dict(method_name="get_coverage", obj_type="red_array", method_kwargs=dict(overlapping=True)), projections=dict(obj_type="array"), ) ) """_""" __pdoc__[ "ranges_shortcut_config" ] = f"""Config of shortcut properties to be attached to `Ranges`. ```python {ranges_shortcut_config.prettify()} ``` """ RangesT = tp.TypeVar("RangesT", bound="Ranges") @attach_shortcut_properties(ranges_shortcut_config) @attach_fields(ranges_attach_field_config) @override_field_config(ranges_field_config) class Ranges(PriceRecords): """Extends `vectorbtpro.generic.price_records.PriceRecords` for working with range records. Requires `records_arr` to have all fields defined in `vectorbtpro.generic.enums.range_dt`.""" @property def field_config(self) -> Config: return self._field_config @classmethod def from_array( cls: tp.Type[RangesT], arr: tp.ArrayLike, gap_value: tp.Optional[tp.Scalar] = None, attach_as_close: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> RangesT: """Build `Ranges` from an array. 
@classmethod
def from_delta(
    cls: tp.Type[RangesT],
    records_or_mapped: tp.Union[Records, MappedArray],
    delta: tp.Union[str, int, tp.FrequencyLike],
    shift: tp.Optional[int] = None,
    idx_field_or_arr: tp.Union[None, str, tp.Array1d] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    **kwargs,
) -> RangesT:
    """Build `Ranges` by applying a delta to the index field of a record/mapped array.

    The range records are generated by `vectorbtpro.generic.nb.records.get_ranges_from_delta_nb`.

    Args:
        records_or_mapped (Records or MappedArray): Object that provides the row positions,
            the ids, the column mapper, and the wrapper.
        delta (str, int, or frequency_like): An integer is forwarded as a number of rows.
            Anything else is converted using `vectorbtpro.utils.datetime_.to_timedelta64`
            and requires the index to be datetime-like, or at least the frequency to be set.
        shift (int): Forwarded to the Numba function. None is replaced by 0.
        idx_field_or_arr (str or array_like): Row positions to use. None takes the field "idx"
            of records or `idx_arr` of a mapped array. A string is looked up as a field
            and is allowed for records only.
        jitted (any): Option resolved by the jitting registry.
        chunked (any): Option resolved by the chunking registry.
        **kwargs: Keyword arguments passed to `Ranges.from_records`.

            If `records_or_mapped` is an instance of `PriceRecords`, its open, high, low,
            and close are used as defaults.

    Raises:
        ValueError: If an index field name is provided together with a mapped array.

    !!! note
        The result is always built with `Ranges.from_records`, regardless of `cls`.
    """
    is_records = isinstance(records_or_mapped, Records)

    # Resolve the row positions and the ids of the source records
    if idx_field_or_arr is None:
        idx_field_or_arr = records_or_mapped.get_field_arr("idx") if is_records else records_or_mapped.idx_arr
    if isinstance(idx_field_or_arr, str):
        if not is_records:
            raise ValueError("Providing an index field is allowed for records only")
        idx_field_or_arr = records_or_mapped.get_field_arr(idx_field_or_arr)
    id_arr = records_or_mapped.get_field_arr("id") if is_records else records_or_mapped.id_arr

    # An integer delta counts rows; anything else is measured on the index in nanoseconds
    source_wrapper = records_or_mapped.wrapper
    delta_use_index = not isinstance(delta, int)
    index = None
    if delta_use_index:
        delta = dt.to_ns(dt.to_timedelta64(delta))
        if isinstance(source_wrapper.index, pd.DatetimeIndex):
            index = dt.to_ns(source_wrapper.index)
        else:
            # No datetime index: emulate one from the row number and the frequency
            freq = dt.to_ns(dt.to_timedelta64(source_wrapper.freq))
            index = np.arange(source_wrapper.shape[0]) * freq
    if shift is None:
        shift = 0

    col_map = records_or_mapped.col_mapper.get_col_map(group_by=False)
    func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.get_ranges_from_delta_nb, jitted),
        chunked,
    )
    new_records_arr = func(
        source_wrapper.shape[0],
        idx_field_or_arr,
        id_arr,
        col_map,
        index=index,
        delta=delta,
        delta_use_index=delta_use_index,
        shift=shift,
    )

    if isinstance(records_or_mapped, PriceRecords):
        # Carry over the price data unless the caller overrides it
        price_defaults = dict(
            open=records_or_mapped._open,
            high=records_or_mapped._high,
            low=records_or_mapped._low,
            close=records_or_mapped._close,
        )
        kwargs = merge_dicts(price_defaults, kwargs)
    return Ranges.from_records(source_wrapper, new_records_arr, **kwargs)
def filter_min_duration(
    self: RangesT,
    min_duration: tp.Union[str, int, tp.FrequencyLike],
    real: bool = False,
    **kwargs,
) -> RangesT:
    """Keep only ranges that last at least a minimum duration.

    Args:
        min_duration (str, int, or frequency_like): An integer is compared against
            `Ranges.duration` directly; `real` has no effect in this case.
            Anything else is converted using `vectorbtpro.utils.datetime_.to_timedelta64`.
        real (bool): Whether to compare a non-integer `min_duration` against
            `Ranges.real_duration` instead of `Ranges.duration` multiplied by the wrapper frequency.
        **kwargs: Keyword arguments passed to `Ranges.apply_mask`.
    """
    if isinstance(min_duration, int):
        keep_mask = self.duration.values >= min_duration
    else:
        threshold = dt.to_timedelta64(min_duration)
        if real:
            keep_mask = self.real_duration.values >= threshold
        else:
            keep_mask = self.duration.values * self.wrapper.freq >= threshold
    return self.apply_mask(keep_mask, **kwargs)
def get_last_idx(self, **kwargs):
    """Get the last index in each range.

    For ranges with the status `vectorbtpro.generic.enums.RangeStatus.Closed`, this is
    the field "end_idx" minus one. For all other ranges, "end_idx" is taken as is.

    `**kwargs` are passed to `Ranges.map_array`."""
    is_closed = self.get_field_arr("status") == enums.RangeStatus.Closed
    end_idx = self.get_field_arr("end_idx")
    # np.where allocates a new array, so the record field itself is never modified
    return self.map_array(np.where(is_closed, end_idx - 1, end_idx), **kwargs)
def get_avg_duration(
    self,
    real: bool = False,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeSeries:
    """Get average range duration (as timedelta).

    Args:
        real (bool): Whether to average `Ranges.real_duration` (converted to nanoseconds
            before reducing) instead of `Ranges.duration`.
        group_by (any): Forwarded to the `mean` call of the mapped array.
        jitted (any): Forwarded to the `mean` call of the mapped array.
        chunked (any): Forwarded to the `mean` call of the mapped array.
        wrap_kwargs (dict): Keyword arguments for wrapping. Override the defaults,
            which name the result "avg_real_duration" or "avg_duration".
        **kwargs: Additional keyword arguments forwarded to the `mean` call.
    """
    if not real:
        mapped_duration = self.duration
        wrap_defaults = dict(to_timedelta=True, name_or_index="avg_duration")
    else:
        real_duration = self.real_duration
        # Reduce over integer nanoseconds; the dtype below turns the result back into a timedelta
        mapped_duration = real_duration.replace(mapped_arr=dt.to_ns(real_duration.mapped_arr))
        wrap_defaults = dict(name_or_index="avg_real_duration", dtype="timedelta64[ns]")
    return mapped_duration.mean(
        group_by=group_by,
        jitted=jitted,
        chunked=chunked,
        wrap_kwargs=merge_dicts(wrap_defaults, wrap_kwargs),
        **kwargs,
    )
def get_projections(
    self,
    close: tp.Optional[tp.ArrayLike] = None,
    proj_start: tp.Union[None, str, int, tp.FrequencyLike] = None,
    proj_period: tp.Union[None, str, int, tp.FrequencyLike] = None,
    incl_end_idx: bool = True,
    extend: bool = False,
    rebase: bool = True,
    start_value: tp.ArrayLike = 1.0,
    ffill: bool = False,
    remove_empty: bool = True,
    return_raw: bool = False,
    start_index: tp.Optional[pd.Timestamp] = None,
    id_level: tp.Union[None, str, tp.IndexLike] = None,
    jitted: tp.JittedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    clean_index_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.Tuple[tp.Array1d, tp.Array2d], tp.Frame]:
    """Generate a projection for each range record.

    See `vectorbtpro.generic.nb.records.map_ranges_to_projections_nb`.

    Args:
        close (array_like): Price to project. Defaults to `Ranges.close`, which must not be None.
            A provided array is wrapped with the wrapper of this instance.
        proj_start (str, int, or frequency_like): Where the projection starts relative to the start row.

            An integer means a number of rows, anything else is converted to a timedelta using
            `vectorbtpro.utils.datetime_.to_timedelta64`. The second option requires the index
            to be datetime-like, or at least the frequency to be set. None becomes 0.
        proj_period (str, int, or frequency_like): Length of the projection, given the same way as `proj_start`.

            Unless `extend` is True, the duration of the range is still respected.
        incl_end_idx (bool): Forwarded to the Numba function.
        extend (bool): Whether to extend the projection even after the end of the range.

            The extending period is taken from the longest range duration if `proj_period` is None,
            and from the longest `proj_period` if it's not None.
        rebase (bool): Whether to make each projection start with `start_value`.

            Otherwise, each projection consists of the original `close` values during the projected period.
        start_value (array_like): Value to start each rebased projection with.

            Can be a flexible array with elements per column. If -1, uses the latest row in `close`.
        ffill (bool): Whether to forward fill NaN values, even if they are NaN in `close` itself.
        remove_empty (bool): Whether to remove projections that are either NaN or have only one element.

            The index of each projection is still tracked and appears in the columns
            of the returned DataFrame.
        return_raw (bool): Whether to return the output of the Numba function as is.
        start_index (pd.Timestamp): First timestamp of the returned index.
            Defaults to the last index value of `close`.
        id_level (str or index_like): Column level that identifies each range.

            None uses the record ids under the name "range_id". A string is treated as a field name.
            Anything else is used as an index and named "range_id" unless it's already a `pd.Index`.
        jitted (any): Option resolved by the jitting registry.
        wrap_kwargs (dict): Keyword arguments for wrapping. Override the generated index and columns.
        clean_index_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.indexes.stack_indexes`.

    !!! note
        As opposed to the Numba-compiled function, the returned DataFrame will have
        projections stacked along columns rather than rows. Set `return_raw` to True
        to return them in the original format.
    """

    def _index_in_ns():
        # Index as integer nanoseconds; without a datetime index,
        # it's emulated from the row number and the frequency
        if isinstance(self.wrapper.index, pd.DatetimeIndex):
            return dt.to_ns(self.wrapper.index)
        freq = dt.to_ns(dt.to_timedelta64(self.wrapper.freq))
        return np.arange(self.wrapper.shape[0]) * freq

    if close is not None:
        close = self.wrapper.wrap(close, group_by=False)
    else:
        close = self.close
        checks.assert_not_none(close, arg_name="close")

    # Integers count rows, everything else is measured on the index
    if proj_start is None:
        proj_start = 0
    index = None
    proj_start_use_index = not isinstance(proj_start, int)
    if proj_start_use_index:
        proj_start = dt.to_ns(dt.to_timedelta64(proj_start))
        index = _index_in_ns()
    proj_period_use_index = proj_period is not None and not isinstance(proj_period, int)
    if proj_period_use_index:
        proj_period = dt.to_ns(dt.to_timedelta64(proj_period))
        if index is None:
            index = _index_in_ns()

    func = jit_reg.resolve_option(nb.map_ranges_to_projections_nb, jitted)
    ridxs, projections = func(
        to_2d_array(close),
        self.get_field_arr("col"),
        self.get_field_arr("start_idx"),
        self.get_field_arr("end_idx"),
        self.get_field_arr("status"),
        index=index,
        proj_start=proj_start,
        proj_start_use_index=proj_start_use_index,
        proj_period=proj_period,
        proj_period_use_index=proj_period_use_index,
        incl_end_idx=incl_end_idx,
        extend=extend,
        rebase=rebase,
        start_value=to_1d_array(start_value),
        ffill=ffill,
        remove_empty=remove_empty,
    )
    if return_raw:
        return ridxs, projections

    # One projection per column from here on
    projections = projections.T
    wrapper = ArrayWrapper.from_obj(projections, freq=self.wrapper.freq)

    # Build the column level that identifies the range behind each projection
    if id_level is None:
        id_level = pd.Index(self.id_arr, name="range_id")
    elif isinstance(id_level, str):
        field = id_level
        mapping = self.get_field_mapping(field)
        if isinstance(mapping, str) and mapping == "index":
            id_level = self.get_map_field_to_index(field).rename(field)
        else:
            id_level = pd.Index(self.get_apply_mapping_arr(field), name=field)
    elif not isinstance(id_level, pd.Index):
        id_level = pd.Index(id_level, name="range_id")

    if start_index is None:
        start_index = close.index[-1]
    proj_index = pd.date_range(
        start=start_index,
        periods=projections.shape[0],
        freq=self.wrapper.freq,
    )
    proj_columns = stack_indexes(
        self.wrapper.columns[self.col_arr[ridxs]],
        id_level[ridxs],
        **resolve_dict(clean_index_kwargs),
    )
    wrap_kwargs = merge_dicts(dict(index=proj_index, columns=proj_columns), wrap_kwargs)
    return wrapper.wrap(projections, **wrap_kwargs)
tp.FrequencyLike] = None, max_duration: tp.Union[str, int, tp.FrequencyLike] = None, last_n: tp.Optional[int] = None, top_n: tp.Optional[int] = None, random_n: tp.Optional[int] = None, seed: tp.Optional[int] = None, proj_start: tp.Union[None, str, int, tp.FrequencyLike] = "current_or_0", proj_period: tp.Union[None, str, int, tp.FrequencyLike] = "max", incl_end_idx: bool = True, extend: bool = False, ffill: bool = False, plot_past_period: tp.Union[None, str, int, tp.FrequencyLike] = "current_or_proj_period", plot_ohlc: tp.Union[bool, tp.Frame] = True, plot_close: tp.Union[bool, tp.Series] = True, plot_projections: bool = True, plot_bands: bool = True, plot_lower: tp.Union[bool, str, tp.Callable] = True, plot_middle: tp.Union[bool, str, tp.Callable] = True, plot_upper: tp.Union[bool, str, tp.Callable] = True, plot_aux_middle: tp.Union[bool, str, tp.Callable] = True, plot_fill: bool = True, colorize: bool = True, ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None, ohlc_trace_kwargs: tp.KwargsLike = None, close_trace_kwargs: tp.KwargsLike = None, projection_trace_kwargs: tp.KwargsLike = None, lower_trace_kwargs: tp.KwargsLike = None, middle_trace_kwargs: tp.KwargsLike = None, upper_trace_kwargs: tp.KwargsLike = None, aux_middle_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot projections. Combines generation of projections using `Ranges.get_projections` and their plotting using `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. Args: column (str): Name of the column to plot. min_duration (str, int, or frequency_like): Filter range records by minimum duration. max_duration (str, int, or frequency_like): Filter range records by maximum duration. last_n (int): Select last N range records. top_n (int): Select top N range records by maximum duration. random_n (int): Select N range records randomly. 
seed (int): Seed to make output deterministic. proj_start (str, int, or frequency_like): See `Ranges.get_projections`. Allows an additional option "current_or_{value}", which sets `proj_start` to the duration of the current open range, and to the specified value if there is no open range. proj_period (str, int, or frequency_like): See `Ranges.get_projections`. Allows additional options "current_or_{option}", "mean", "min", "max", "median", or a percentage such as "50%" representing a quantile. All of those options are based on the duration of all the closed ranges filtered by the arguments above. incl_end_idx (bool): See `Ranges.get_projections`. extend (bool): See `Ranges.get_projections`. ffill (bool): See `Ranges.get_projections`. plot_past_period (str, int, or frequency_like): Past period to plot. Allows the same options as `proj_period` plus "proj_period" and "current_or_proj_period". plot_ohlc (bool or DataFrame): Whether to plot OHLC. plot_close (bool or Series): Whether to plot close. plot_projections (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_bands (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_lower (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_middle (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_upper (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_aux_middle (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. plot_fill (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. colorize (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`. ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace. Pass None to use the default. ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`. 
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`. projection_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for projections. lower_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for lower band. middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for middle band. upper_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for upper band. aux_middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for auxiliary middle band. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> price = pd.Series( ... [11, 12, 13, 14, 11, 12, 13, 12, 11, 12], ... index=pd.date_range("2020", periods=10), ... ) >>> vbt.Ranges.from_array( ... price >= 12, ... attach_as_close=False, ... close=price, ... ).plot_projections( ... proj_start=0, ... proj_period=4, ... extend=True, ... plot_past_period=None ... 
).show() ``` ![](/assets/images/api/ranges_plot_projections.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/ranges_plot_projections.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") from vectorbtpro.utils.figure import make_figure from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=False) self_col_open = self_col.status_open self_col = self_col.status_closed if proj_start is not None: if isinstance(proj_start, str) and proj_start.startswith("current_or_"): if self_col_open.count() > 0: if self_col_open.count() > 1: raise ValueError("Only one open range is allowed") proj_start = int(self_col_open.duration.values[0]) else: proj_start = proj_start.replace("current_or_", "") if proj_start.isnumeric(): proj_start = int(proj_start) if proj_start != 0: self_col = self_col.filter_min_duration(proj_start, real=True) if min_duration is not None: self_col = self_col.filter_min_duration(min_duration, real=True) if max_duration is not None: self_col = self_col.filter_max_duration(max_duration, real=True) if last_n is not None: self_col = self_col.last_n(last_n) if top_n is not None: self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n)) if random_n is not None: self_col = self_col.random_n(random_n, seed=seed) if self_col.count() == 0: warn("No ranges to plot. 
Relax the requirements.") if ohlc_trace_kwargs is None: ohlc_trace_kwargs = {} if close_trace_kwargs is None: close_trace_kwargs = {} close_trace_kwargs = merge_dicts( dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"), close_trace_kwargs, ) if isinstance(plot_ohlc, bool): if ( self_col._open is not None and self_col._high is not None and self_col._low is not None and self_col._close is not None ): ohlc = pd.DataFrame( { "open": self_col.open, "high": self_col.high, "low": self_col.low, "close": self_col.close, } ) else: ohlc = None else: ohlc = plot_ohlc plot_ohlc = True if isinstance(plot_close, bool): if ohlc is not None: close = ohlc.vbt.ohlcv.close else: close = self_col.close else: close = plot_close plot_close = True if close is None: raise ValueError("Close cannot be None") # Resolve windows def _resolve_period(period): if self_col.count() == 0: period = None if period is not None: if isinstance(period, str): period = period.lower().replace(" ", "") if period == "median": period = "50%" if "%" in period: period = int( np.quantile( self_col.duration.values, float(period.replace("%", "")) / 100, ) ) elif period.startswith("current_or_"): if self_col_open.count() > 0: if self_col_open.count() > 1: raise ValueError("Only one open range is allowed") period = int(self_col_open.duration.values[0]) else: period = period.replace("current_or_", "") return _resolve_period(period) elif period == "mean": period = int(np.mean(self_col.duration.values)) elif period == "min": period = int(np.min(self_col.duration.values)) elif period == "max": period = int(np.max(self_col.duration.values)) return period proj_period = _resolve_period(proj_period) if isinstance(proj_period, int) and proj_period == 0: warn("Projection period is zero. 
Setting to maximum.") proj_period = int(np.max(self_col.duration.values)) if plot_past_period is not None and isinstance(plot_past_period, str): plot_past_period = plot_past_period.lower().replace(" ", "") if plot_past_period == "proj_period": plot_past_period = proj_period elif plot_past_period == "current_or_proj_period": if self_col_open.count() > 0: if self_col_open.count() > 1: raise ValueError("Only one open range is allowed") plot_past_period = int(self_col_open.duration.values[0]) else: plot_past_period = proj_period plot_past_period = _resolve_period(plot_past_period) if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) # Plot OHLC/close if plot_ohlc and ohlc is not None: if plot_past_period is not None: if isinstance(plot_past_period, int): _ohlc = ohlc.iloc[-plot_past_period:] else: plot_past_period = dt.to_timedelta(plot_past_period) _ohlc = ohlc[ohlc.index > ohlc.index[-1] - plot_past_period] else: _ohlc = ohlc if _ohlc.size > 0: if "opacity" not in ohlc_trace_kwargs: ohlc_trace_kwargs["opacity"] = 0.5 fig = _ohlc.vbt.ohlcv.plot( ohlc_type=ohlc_type, plot_volume=False, ohlc_trace_kwargs=ohlc_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) elif plot_close: if plot_past_period is not None: if isinstance(plot_past_period, int): _close = close.iloc[-plot_past_period:] else: plot_past_period = dt.to_timedelta(plot_past_period) _close = close[close.index > close.index[-1] - plot_past_period] else: _close = close if _close.size > 0: fig = _close.vbt.lineplot( trace_kwargs=close_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if self_col.count() > 0: # Get projections projections = self_col.get_projections( close=close, proj_start=proj_start, proj_period=proj_period, incl_end_idx=incl_end_idx, extend=extend, rebase=True, start_value=close.iloc[-1], ffill=ffill, remove_empty=True, return_raw=False, ) if len(projections.columns) > 0: # Plot projections rename_levels = dict(range_id=self_col.get_field_title("id")) fig = 
def plot_shapes(
    self,
    column: tp.Optional[tp.Label] = None,
    plot_ohlc: tp.Union[bool, tp.Frame] = True,
    plot_close: tp.Union[bool, tp.Series] = True,
    ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
    ohlc_trace_kwargs: tp.KwargsLike = None,
    close_trace_kwargs: tp.KwargsLike = None,
    shape_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    xref: str = "x",
    yref: str = "y",
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot range shapes.

    Draws one shape per range record on top of (optionally) the OHLC or close data.
    By default, each shape is a rectangle spanning the range horizontally and
    the domain of `yref` vertically.

    Args:
        column (str): Name of the column to plot.
        plot_ohlc (bool or DataFrame): Whether to plot OHLC.

            If a DataFrame is passed, it is used as the OHLC data.
        plot_close (bool or Series): Whether to plot close.

            If a Series is passed, it is used as the close data.
            Close is only plotted if OHLC is not plotted.
        ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

            Pass None to use the default.
        ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
        close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`.
        shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for shapes.

            Can contain templates, which are substituted for each record with a context
            containing `self_col`, `i`, `record`, `start_index`, `end_index`, `xref`, `yref`,
            `x_domain`, `y_domain`, `close`, and `ohlc`.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        xref (str): X coordinate axis.
        yref (str): Y coordinate axis.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Usage:
        * Plot zones colored by duration:

        ```pycon
        >>> price = pd.Series(
        ...     [1, 2, 1, 2, 3, 2, 1, 2, 3],
        ...     index=pd.date_range("2020", periods=9),
        ... )

        >>> def get_opacity(self_col, i):
        ...     real_duration = self_col.get_real_duration().values
        ...     return real_duration[i] / real_duration.max() * 0.5

        >>> vbt.Ranges.from_array(price >= 2).plot_shapes(
        ...     shape_kwargs=dict(fillcolor="teal", opacity=vbt.RepFunc(get_opacity))
        ... ).show()
        ```

        ![](/assets/images/api/ranges_plot_shapes.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/ranges_plot_shapes.dark.svg#only-dark){: .iimg loading=lazy }
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    from vectorbtpro.utils.figure import make_figure, get_domain
    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]

    self_col = self.select_col(column=column, group_by=False)
    if ohlc_trace_kwargs is None:
        ohlc_trace_kwargs = {}
    if close_trace_kwargs is None:
        close_trace_kwargs = {}
    close_trace_kwargs = merge_dicts(
        dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
        close_trace_kwargs,
    )
    if shape_kwargs is None:
        shape_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}

    # Resolve OHLC: a boolean means "use the attached price data if all four are available",
    # anything else is treated as user-provided OHLC data
    if isinstance(plot_ohlc, bool):
        if (
            self_col._open is not None
            and self_col._high is not None
            and self_col._low is not None
            and self_col._close is not None
        ):
            ohlc = pd.DataFrame(
                {
                    "open": self_col.open,
                    "high": self_col.high,
                    "low": self_col.low,
                    "close": self_col.close,
                }
            )
        else:
            ohlc = None
    else:
        ohlc = plot_ohlc
        plot_ohlc = True

    # Resolve close: prefer the close column of the resolved OHLC
    if isinstance(plot_close, bool):
        if ohlc is not None:
            close = ohlc.vbt.ohlcv.close
        else:
            close = self_col.close
    else:
        close = plot_close
        plot_close = True

    if fig is None:
        fig = make_figure()
    fig.update_layout(**layout_kwargs)
    # The X domain must be derived from the X axis reference (previously used `yref`)
    x_domain = get_domain(xref, fig)
    y_domain = get_domain(yref, fig)

    # Plot OHLC/close
    if plot_ohlc and ohlc is not None:
        if "opacity" not in ohlc_trace_kwargs:
            ohlc_trace_kwargs["opacity"] = 0.5
        fig = ohlc.vbt.ohlcv.plot(
            ohlc_type=ohlc_type,
            plot_volume=False,
            ohlc_trace_kwargs=ohlc_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
    elif plot_close and close is not None:
        fig = close.vbt.lineplot(
            trace_kwargs=close_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

    if self_col.count() > 0:
        start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True)
        end_idx = self_col.get_map_field_to_index("end_idx")
        for i, record in enumerate(self_col.values):
            start_index = start_idx[i]
            end_index = end_idx[i]
            # Templates are resolved per record so that users can style each shape individually
            _shape_kwargs = substitute_templates(
                shape_kwargs,
                context=dict(
                    self_col=self_col,
                    i=i,
                    record=record,
                    start_index=start_index,
                    end_index=end_index,
                    xref=xref,
                    yref=yref,
                    x_domain=x_domain,
                    y_domain=y_domain,
                    close=close,
                    ohlc=ohlc,
                ),
                eval_id="shape_kwargs",
            )
            # User-provided keys take precedence over the default rectangle
            _shape_kwargs = merge_dicts(
                dict(
                    type="rect",
                    xref=xref,
                    yref="paper",
                    x0=start_index,
                    y0=y_domain[0],
                    x1=end_index,
                    y1=y_domain[1],
                    fillcolor="gray",
                    opacity=0.15,
                    layer="below",
                    line_width=0,
                ),
                _shape_kwargs,
            )
            fig.add_shape(**_shape_kwargs)

    return fig
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace. Pass None to use the default. ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`. close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`. start_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for start values. end_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for end values. open_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for open zones. closed_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for closed zones. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. xref (str): X coordinate axis. yref (str): Y coordinate axis. fig (Figure or FigureWidget): Figure to add traces to. return_close (bool): Whether to return the close series along with the figure. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> price = pd.Series( ... [1, 2, 1, 2, 3, 2, 1, 2, 3], ... index=pd.date_range("2020", periods=9), ... 
) >>> vbt.Ranges.from_array(price >= 2).plot().show() ``` ![](/assets/images/api/ranges_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/ranges_plot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") import plotly.graph_objects as go from vectorbtpro.utils.figure import make_figure, get_domain from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=False) if top_n is not None: self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n)) if ohlc_trace_kwargs is None: ohlc_trace_kwargs = {} if close_trace_kwargs is None: close_trace_kwargs = {} close_trace_kwargs = merge_dicts( dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"), close_trace_kwargs, ) if start_trace_kwargs is None: start_trace_kwargs = {} if end_trace_kwargs is None: end_trace_kwargs = {} if open_shape_kwargs is None: open_shape_kwargs = {} if closed_shape_kwargs is None: closed_shape_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} if isinstance(plot_ohlc, bool): if ( self_col._open is not None and self_col._high is not None and self_col._low is not None and self_col._close is not None ): ohlc = pd.DataFrame( { "open": self_col.open, "high": self_col.high, "low": self_col.low, "close": self_col.close, } ) else: ohlc = None else: ohlc = plot_ohlc plot_ohlc = True if isinstance(plot_close, bool): if ohlc is not None: close = ohlc.vbt.ohlcv.close else: close = self_col.close else: close = plot_close plot_close = True if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) y_domain = get_domain(yref, fig) # Plot OHLC/close plotting_ohlc = False if plot_ohlc and ohlc is not None: if "opacity" not in ohlc_trace_kwargs: ohlc_trace_kwargs["opacity"] = 0.5 fig = ohlc.vbt.ohlcv.plot( ohlc_type=ohlc_type, plot_volume=False, ohlc_trace_kwargs=ohlc_trace_kwargs, 
add_trace_kwargs=add_trace_kwargs, fig=fig, ) plotting_ohlc = True elif plot_close and close is not None: fig = close.vbt.lineplot( trace_kwargs=close_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if self_col.count() > 0: # Extract information start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True) if plotting_ohlc and self_col.open is not None: start_val = self_col.open.loc[start_idx] elif close is not None: start_val = close.loc[start_idx] else: start_val = np.full(len(start_idx), 0) end_idx = self_col.get_map_field_to_index("end_idx") if close is not None: end_val = close.loc[end_idx] else: end_val = np.full(len(end_idx), 0) status = self_col.get_field_arr("status") if plot_markers: # Plot start markers start_customdata, start_hovertemplate = self_col.prepare_customdata(incl_fields=["id", "start_idx"]) _start_trace_kwargs = merge_dicts( dict( x=start_idx, y=start_val, mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["blue"], size=7, line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["blue"])), ), name="Start", customdata=start_customdata, hovertemplate=start_hovertemplate, ), start_trace_kwargs, ) start_scatter = go.Scatter(**_start_trace_kwargs) fig.add_trace(start_scatter, **add_trace_kwargs) closed_mask = status == enums.RangeStatus.Closed if closed_mask.any(): if plot_markers: # Plot end markers closed_end_customdata, closed_end_hovertemplate = self_col.prepare_customdata(mask=closed_mask) _end_trace_kwargs = merge_dicts( dict( x=end_idx[closed_mask], y=end_val[closed_mask], mode="markers", marker=dict( symbol="diamond", color=plotting_cfg["contrast_color_schema"]["green"], size=7, line=dict( width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"]) ), ), name="Closed", customdata=closed_end_customdata, hovertemplate=closed_end_hovertemplate, ), end_trace_kwargs, ) closed_end_scatter = go.Scatter(**_end_trace_kwargs) 
@property
def plots_defaults(self) -> tp.Kwargs:
    """Defaults for `Ranges.plots`.

    Takes `vectorbtpro.records.base.Records.plots_defaults` as the base and merges
    `plots` from `vectorbtpro._settings.ranges` over it."""
    from vectorbtpro._settings import settings as vbt_settings

    own_defaults = vbt_settings["ranges"]["plots"]
    # Invoke the parent's property explicitly on this instance
    inherited_defaults = Records.plots_defaults.__get__(self)
    return merge_dicts(inherited_defaults, own_defaults)
Should only be used when the array has fixed bounds. Used in rescaling using `RescaleMode.MinMax` and checking against `PSC.min_pct_change` and `PSC.max_pct_change`. If `np.nan`, gets calculated dynamically.""" vmax: tp.Union[float] = define.optional_field() """Maximum value of any window. Should only be used when the array has fixed bounds. Used in rescaling using `RescaleMode.MinMax` and checking against `PSC.min_pct_change` and `PSC.max_pct_change`. If `np.nan`, gets calculated dynamically.""" pmin: tp.Union[float] = define.optional_field() """Value to be considered as the minimum of `PSC.pattern`. Used in rescaling using `RescaleMode.MinMax` and calculating the maximum distance at each point if `PSC.max_error_as_maxdist` is disabled. If `np.nan`, gets calculated dynamically.""" pmax: tp.Union[float] = define.optional_field() """Value to be considered as the maximum of `PSC.pattern`. Used in rescaling using `RescaleMode.MinMax` and calculating the maximum distance at each point if `PSC.max_error_as_maxdist` is disabled. If `np.nan`, gets calculated dynamically.""" invert: tp.Union[bool] = define.optional_field() """Whether to invert the pattern vertically.""" error_type: tp.Union[int, str] = define.optional_field() """Error type. See `vectorbtpro.generic.enums.ErrorType`.""" distance_measure: tp.Union[int, str] = define.optional_field() """Distance measure. See `vectorbtpro.generic.enums.DistanceMeasure`.""" max_error: tp.Union[tp.ArrayLike] = define.optional_field() """Maximum error at each point. Can be provided as a flexible array. If `max_error` is an array, it must be of the same size as the pattern array. It also should be provided within the same scale as the pattern.""" max_error_interp_mode: tp.Union[None, int, str] = define.optional_field() """Interpolation mode for `PSC.max_error`. See `vectorbtpro.generic.enums.InterpMode`. 
If None, defaults to `PSC.interp_mode`.""" max_error_as_maxdist: tp.Union[bool] = define.optional_field() """Whether `PSC.max_error` should be used as the maximum distance at each point. If False, crossing `PSC.max_error` will set the distance to the maximum distance based on `PSC.pmin`, `PSC.pmax`, and the pattern value at that point. If True and any of the points in a window is `np.nan`, the point will be skipped.""" max_error_strict: tp.Union[bool] = define.optional_field() """Whether crossing `PSC.max_error` even once should yield the similarity of `np.nan`.""" min_pct_change: tp.Union[float] = define.optional_field() """Minimum percentage change of the window to stay a candidate for search. If any window doesn't cross this mark, its similarity becomes `np.nan`.""" max_pct_change: tp.Union[float] = define.optional_field() """Maximum percentage change of the window to stay a candidate for search. If any window crosses this mark, its similarity becomes `np.nan`.""" min_similarity: tp.Union[float] = define.optional_field() """Minimum similarity. If any window doesn't cross this mark, its similarity becomes `np.nan`.""" minp: tp.Optional[int] = define.optional_field() """Minimum number of observations in price window required to have a value.""" overlap_mode: tp.Union[int, str] = define.optional_field() """Overlapping mode. See `vectorbtpro.generic.enums.OverlapMode`.""" max_records: tp.Optional[int] = define.optional_field() """Maximum number of records expected to be filled. 
@classmethod
def resolve_search_config(cls, search_config: tp.Union[None, dict, PSC] = None, **kwargs) -> PSC:
    """Resolve search config for `PatternRanges.from_pattern_search`.

    Fields that are missing in the config are filled from `**kwargs`, and then from the defaults
    in the signature of `PatternRanges.from_pattern_search`. Array-like fields are converted
    into one-dimensional arrays and enumerated fields are mapped using their enums.

    Raises an error if the resolved pattern is None."""
    if search_config is None:
        search_config = {}
    if isinstance(search_config, dict):
        search_config = PSC(**search_config)
    fields = search_config.asdict()

    # Signature defaults of the search method (limited to config fields), overridden by kwargs
    signature_defaults = {
        arg_name: arg_default
        for arg_name, arg_default in get_func_kwargs(cls.from_pattern_search).items()
        if arg_name in fields
    }
    fallbacks = merge_dicts(signature_defaults, kwargs)

    # Fields that are mapped with an enum without any further processing
    enum_by_field = {
        "interp_mode": enums.InterpMode,
        "rescale_mode": enums.RescaleMode,
        "error_type": enums.ErrorType,
        "distance_measure": enums.DistanceMeasure,
        "overlap_mode": enums.OverlapMode,
    }

    for field_name in fields:
        value = fields[field_name]
        if value is MISSING:
            value = fallbacks[field_name]
        if field_name == "pattern":
            if value is None:
                raise ValueError("Must provide pattern")
            value = to_1d_array(value)
        elif field_name == "max_error":
            value = to_1d_array(value)
        elif field_name == "max_error_interp_mode":
            if value is None:
                # Falls back to the value currently stored under `interp_mode`
                # NOTE(review): relies on `interp_mode` being visited (and resolved) earlier - confirm field order
                value = fields["interp_mode"]
            else:
                value = map_enum_fields(value, enums.InterpMode)
        elif field_name in enum_by_field:
            value = map_enum_fields(value, enum_by_field[field_name])
        # Written back right away so that later fields can see resolved values
        fields[field_name] = value
    return PSC(**fields)
tp.Union[Param, bool] = False, max_error_strict: tp.Union[Param, bool] = False, min_pct_change: tp.Union[Param, float] = np.nan, max_pct_change: tp.Union[Param, float] = np.nan, min_similarity: tp.Union[Param, float] = 0.85, minp: tp.Union[Param, None, int] = None, overlap_mode: tp.Union[Param, int, str] = "disallow", max_records: tp.Union[Param, None, int] = None, random_subset: tp.Optional[int] = None, seed: tp.Optional[int] = None, search_configs: tp.Optional[tp.Sequence[tp.MaybeSequence[PSC]]] = None, jitted: tp.JittedOption = None, execute_kwargs: tp.KwargsLike = None, attach_as_close: bool = True, clean_index_kwargs: tp.KwargsLike = None, wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> PatternRangesT: """Build `PatternRanges` from all occurrences of a pattern in an array. Searches for parameters of the type `vectorbtpro.utils.params.Param`, and if found, broadcasts and combines them using `vectorbtpro.utils.params.combine_params`. Then, converts them into a list of search configurations. If none of such parameters was found among the passed arguments, builds one search configuration using the passed arguments. If `search_configs` is not None, uses it instead. In all cases, it uses the defaults defined in the signature of this method to augment search configurations. For example, passing `min_similarity` of 95% will use it in all search configurations except where it was explicitly overridden. Argument `search_configs` must be provided as a sequence of `PSC` instances. If any element is a list of `PSC` instances itself, it will be used per column in `arr`, otherwise per entire `arr`. Each configuration will be resolved using `PatternRanges.resolve_search_config` to prepare arguments for the use in Numba. After all the search configurations have been resolved, uses `vectorbtpro.utils.execution.execute` to loop over each configuration and execute it using `vectorbtpro.generic.nb.records.find_pattern_1d_nb`. 
The results are then concatenated into a single records array and wrapped with `PatternRanges`. If `attach_as_close` is True, will attach `arr` as `close`. `**kwargs` will be passed to `PatternRanges.__init__`.""" if seed is not None: set_seed(seed) if clean_index_kwargs is None: clean_index_kwargs = {} arr = to_pd_array(arr) arr_2d = to_2d_array(arr) arr_wrapper = ArrayWrapper.from_obj(arr) psc_keys = [a.name for a in PSC.fields if a.name != "name"] method_locals = locals() method_locals = {k: v for k, v in method_locals.items() if k in psc_keys} # Flatten search configs flat_search_configs = [] psc_names = [] psc_names_none = True n_configs = 0 if search_configs is not None: for maybe_search_config in search_configs: if isinstance(maybe_search_config, dict): maybe_search_config = PSC(**maybe_search_config) if isinstance(maybe_search_config, PSC): for col in range(arr_2d.shape[1]): flat_search_configs.append(maybe_search_config) if maybe_search_config.name is not None: psc_names.append(maybe_search_config.name) psc_names_none = False else: psc_names.append(n_configs) n_configs += 1 else: if len(maybe_search_config) != arr_2d.shape[1]: raise ValueError("Sub-list with PSC instances must match the number of columns") for col, search_config in enumerate(maybe_search_config): if isinstance(search_config, dict): search_config = PSC(**search_config) flat_search_configs.append(search_config) if search_config.name is not None: psc_names.append(search_config.name) psc_names_none = False else: psc_names.append(n_configs) n_configs += 1 # Combine parameters param_dct = {} for k, v in method_locals.items(): if k in psc_keys and isinstance(v, Param): param_dct[k] = v param_columns = None if len(param_dct) > 0: param_product, param_columns = combine_params( param_dct, random_subset=random_subset, clean_index_kwargs=clean_index_kwargs, ) if len(flat_search_configs) == 0: flat_search_configs = [] for i in range(len(param_columns)): search_config = dict() for k, v in 
param_product.items(): search_config[k] = v[i] for col in range(arr_2d.shape[1]): flat_search_configs.append(PSC(**search_config)) else: new_flat_search_configs = [] for i in range(len(param_columns)): for search_config in flat_search_configs: new_search_config = dict() for k, v in search_config.asdict().items(): if v is not MISSING: if k in param_product: raise ValueError(f"Parameter '{k}' is re-defined in a search configuration") new_search_config[k] = v if k in param_product: new_search_config[k] = param_product[k][i] new_flat_search_configs.append(PSC(**new_search_config)) flat_search_configs = new_flat_search_configs # Create config from arguments if empty if len(flat_search_configs) == 0: single_group = True for col in range(arr_2d.shape[1]): flat_search_configs.append(PSC()) else: single_group = False # Prepare function and arguments tasks = [] func = jit_reg.resolve_option(nb.find_pattern_1d_nb, jitted) def_func_kwargs = get_func_kwargs(func) new_search_configs = [] for c in range(len(flat_search_configs)): func_kwargs = { "col": c, "arr": arr_2d[:, c % arr_2d.shape[1]], } new_search_config = cls.resolve_search_config(flat_search_configs[c], **method_locals) for k, v in new_search_config.asdict().items(): if k == "name": continue if isinstance(v, Param): raise TypeError(f"Cannot use Param inside search configs") if k in def_func_kwargs: if v is not def_func_kwargs[k]: func_kwargs[k] = v else: func_kwargs[k] = v tasks.append(Task(func, **func_kwargs)) new_search_configs.append(new_search_config) # Build column hierarchy n_config_params = len(psc_names) // arr_2d.shape[1] if param_columns is not None: if n_config_params == 0 or (n_config_params == 1 and psc_names_none): new_columns = combine_indexes((param_columns, arr_wrapper.columns), **clean_index_kwargs) else: search_config_index = pd.Index(psc_names, name="search_config") base_columns = stack_indexes( (search_config_index, tile_index(arr_wrapper.columns, n_config_params)), **clean_index_kwargs, ) 
@classmethod
def resolve_row_stack_kwargs(
    cls: tp.Type[PatternRangesT],
    *objs: tp.MaybeTuple[PatternRangesT],
    **kwargs,
) -> tp.Kwargs:
    """Resolve keyword arguments for initializing `PatternRanges` after stacking along rows.

    Extends `Ranges.resolve_row_stack_kwargs` by resolving `search_configs`: an object with
    a single search config has it repeated for each column of the stacked wrapper, and the
    search configs of every object must be equal to those of the first object.

    Raises:
        TypeError: If any object is not an instance of `PatternRanges`.
        ValueError: If search configs of any object differ from those of the first object."""
    kwargs = Ranges.resolve_row_stack_kwargs(*objs, **kwargs)
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    for obj in objs:
        if not checks.is_instance_of(obj, PatternRanges):
            raise TypeError("Each object to be merged must be an instance of PatternRanges")
    new_search_configs = []
    for obj in objs:
        if len(obj.search_configs) == 1:
            # Repeat the only config to have one config per column
            new_search_configs.append(obj.search_configs * len(kwargs["wrapper"].columns))
        else:
            new_search_configs.append(obj.search_configs)
        # Validate each object against the first one rather than only the last one
        if len(new_search_configs) >= 2:
            if new_search_configs[-1] != new_search_configs[0]:
                raise ValueError("Objects to be merged must have compatible PSC instances. Pass to override.")
    kwargs["search_configs"] = new_search_configs[0]
    return kwargs
if not isinstance(col_idxs, slice): col_idxs = to_1d_array(col_idxs) col_idxs = np.arange(self.wrapper.shape_2d[1])[col_idxs] new_search_configs = [] for i in col_idxs: new_search_configs.append(self.search_configs[i]) return self.replace( wrapper=ranges_meta["wrapper_meta"]["new_wrapper"], records_arr=ranges_meta["new_records_arr"], search_configs=new_search_configs, open=ranges_meta["open"], high=ranges_meta["high"], low=ranges_meta["low"], close=ranges_meta["close"], ) @property def search_configs(self) -> tp.List[PSC]: """List of `PSC` instances, one per column.""" return self._search_configs # ############# Stats ############# # _metrics: tp.ClassVar[Config] = HybridConfig( { **Ranges.metrics, "similarity": dict( title="Similarity", calc_func="similarity.describe", post_calc_func=lambda self, out, settings: { "Min": out.loc["min"], "Median": out.loc["50%"], "Max": out.loc["max"], }, tags=["pattern_ranges", "similarity"], ), } ) @property def metrics(self) -> Config: return self._metrics # ############# Plots ############# # def plot( self, column: tp.Optional[tp.Label] = None, top_n: tp.Optional[int] = None, fit_ranges: tp.Union[bool, tp.MaybeSequence[int]] = False, plot_patterns: bool = True, plot_max_error: bool = False, fill_distance: bool = True, pattern_trace_kwargs: tp.KwargsLike = None, lower_max_error_trace_kwargs: tp.KwargsLike = None, upper_max_error_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, xref: str = "x", yref: str = "y", fig: tp.Optional[tp.BaseFigure] = None, **kwargs, ) -> tp.BaseFigure: """Plot pattern ranges. Based on `Ranges.plot` and `vectorbtpro.generic.accessors.GenericSRAccessor.plot_pattern`. Args: column (str): Name of the column to plot. top_n (int): Filter top N range records by maximum duration. fit_ranges (bool, int, or sequence of int): Whether or which range records to fit. True to fit to all range records, integer or a sequence of such to fit to specific range records. 
plot_patterns (bool or array_like): Whether to plot `PSC.pattern`. plot_max_error (array_like): Whether to plot `PSC.max_error`. fill_distance (bool): Whether to fill the space between close and pattern. Visible for every interpolation mode except discrete. pattern_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for pattern. lower_max_error_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for lower max error. upper_max_error_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for upper max error. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. xref (str): X coordinate axis. yref (str): Y coordinate axis. fig (Figure or FigureWidget): Figure to add traces to. **kwargs: Keyword arguments passed to `Ranges.plot`. """ from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=False) if top_n is not None: self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n)) search_config = self_col.search_configs[0] if isinstance(fit_ranges, bool) and not fit_ranges: fit_ranges = None if fit_ranges is not None: if fit_ranges is True: self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][-1] + 1] elif checks.is_int(fit_ranges): self_col = self_col.apply_mask(self_col.id_arr == fit_ranges) self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][0] + 1] else: self_col = self_col.apply_mask(np.isin(self_col.id_arr, fit_ranges)) self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][0] + 1] if pattern_trace_kwargs is None: pattern_trace_kwargs = {} if lower_max_error_trace_kwargs is None: lower_max_error_trace_kwargs = {} if upper_max_error_trace_kwargs is None: upper_max_error_trace_kwargs = {} open_shape_kwargs = merge_dicts( dict(fillcolor=plotting_cfg["contrast_color_schema"]["blue"]), 
kwargs.pop("open_shape_kwargs", None), ) closed_shape_kwargs = merge_dicts( dict(fillcolor=plotting_cfg["contrast_color_schema"]["blue"]), kwargs.pop("closed_shape_kwargs", None), ) fig, close = Ranges.plot( self_col, return_close=True, open_shape_kwargs=open_shape_kwargs, closed_shape_kwargs=closed_shape_kwargs, add_trace_kwargs=add_trace_kwargs, xref=xref, yref=yref, fig=fig, **kwargs, ) if self_col.count() > 0: # Extract information start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True) end_idx = self_col.get_map_field_to_index("end_idx") status = self_col.get_field_arr("status") if plot_patterns: # Plot pattern for r in range(len(start_idx)): _start_idx = start_idx[r] _end_idx = end_idx[r] if close is None: raise ValueError("Must provide close to overlay patterns") arr_sr = close.loc[_start_idx:_end_idx] if status[r] == enums.RangeStatus.Closed: arr_sr = arr_sr.iloc[:-1] if fill_distance: obj_trace_kwargs = dict( line=dict(color="rgba(0, 0, 0, 0)", width=0), opacity=0, hoverinfo="skip", showlegend=False, name=None, ) else: obj_trace_kwargs = None _pattern_trace_kwargs = merge_dicts( dict( legendgroup="pattern", showlegend=r == 0, ), pattern_trace_kwargs, ) _lower_max_error_trace_kwargs = merge_dicts( dict( legendgroup="max_error", showlegend=r == 0, ), lower_max_error_trace_kwargs, ) _upper_max_error_trace_kwargs = merge_dicts( dict( legendgroup="max_error", showlegend=False, ), upper_max_error_trace_kwargs, ) fig = arr_sr.vbt.plot_pattern( pattern=search_config.pattern, interp_mode=search_config.interp_mode, rescale_mode=search_config.rescale_mode, vmin=search_config.vmin, vmax=search_config.vmax, pmin=search_config.pmin, pmax=search_config.pmax, invert=search_config.invert, error_type=search_config.error_type, max_error=search_config.max_error if plot_max_error else np.nan, max_error_interp_mode=search_config.max_error_interp_mode, plot_obj=fill_distance, fill_distance=fill_distance, obj_trace_kwargs=obj_trace_kwargs, 
class SimRangeMixin(Base):
    """Mixin class for working with simulation ranges.

    Should be subclassed by a subclass of `vectorbtpro.base.wrapping.Wrapping`.

    The range is kept in `_sim_start` and `_sim_end`. Each of them is either None or
    an array of integer row positions into `wrapper.index`. The end is exclusive:
    `SimRangeMixin.get_sim_end_index` reports `wrapper.index[sim_end - 1]` as the last
    included label, and a position equal to `len(wrapper.index)` stands for
    "one past the last row"."""

    @classmethod
    def row_stack_sim_start(
        cls,
        new_wrapper: ArrayWrapper,
        *objs: tp.MaybeTuple[SimRangeMixinT],
    ) -> tp.Optional[tp.ArrayLike]:
        """Row-stack simulation start.

        Only the first object may define a simulation start; it is broadcast to
        the number of columns in `new_wrapper`.

        Args:
            new_wrapper (ArrayWrapper): Wrapper of the stacked object.
            *objs: Objects to stack, passed either separately or as a single sequence.

        Returns:
            Broadcast start of the first object, or None if it has none.

        Raises:
            ValueError: If any object other than the first one has a simulation start.
        """
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)

        if objs[0]._sim_start is not None:
            new_sim_start = broadcast_array_to(objs[0]._sim_start, len(new_wrapper.columns))
        else:
            new_sim_start = None
        for obj in objs[1:]:
            if obj._sim_start is not None:
                raise ValueError("Objects to be merged (except the first one) must have 'sim_start=None'")
        return new_sim_start

    @classmethod
    def row_stack_sim_end(
        cls,
        new_wrapper: ArrayWrapper,
        *objs: tp.MaybeTuple[SimRangeMixinT],
    ) -> tp.Optional[tp.ArrayLike]:
        """Row-stack simulation end.

        Only the last object may define a simulation end; it is shifted by the number
        of rows that precede the last object in `new_wrapper` and broadcast to the
        number of columns in `new_wrapper`.

        Args:
            new_wrapper (ArrayWrapper): Wrapper of the stacked object.
            *objs: Objects to stack, passed either separately or as a single sequence.

        Returns:
            Shifted and broadcast end of the last object, or None if it has none.

        Raises:
            ValueError: If any object other than the last one has a simulation end.
        """
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)

        if objs[-1]._sim_end is not None:
            # Rows contributed by the preceding objects offset the end position
            new_sim_end = len(new_wrapper.index) - len(objs[-1].wrapper.index) + objs[-1]._sim_end
            new_sim_end = broadcast_array_to(new_sim_end, len(new_wrapper.columns))
        else:
            new_sim_end = None
        for obj in objs[:-1]:
            if obj._sim_end is not None:
                raise ValueError("Objects to be merged (except the last one) must have 'sim_end=None'")
        return new_sim_end

    @classmethod
    def column_stack_sim_start(
        cls,
        new_wrapper: ArrayWrapper,
        *objs: tp.MaybeTuple[SimRangeMixinT],
    ) -> tp.Optional[tp.ArrayLike]:
        """Column-stack simulation start.

        Each start position is translated from the object's own index into `new_wrapper.index`.
        Position 0 stays 0, a position equal to the object's index length becomes the length
        of the new index, and any other position is looked up by label (0 if the label is missing).

        Objects without a simulation start contribute zeros, such that objects with and
        without a start can be stacked together.

        Args:
            new_wrapper (ArrayWrapper): Wrapper of the stacked object.
            *objs: Objects to stack, passed either separately or as a single sequence.

        Returns:
            Concatenated start positions, or None if no object has a simulation start.
        """
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)

        if all(obj._sim_start is None for obj in objs):
            return None

        obj_sim_starts = []
        for obj in objs:
            if obj._sim_start is None:
                # No explicit start means the first row, which maps to the first row
                obj_sim_starts.append(np.zeros(obj.wrapper.shape_2d[1], dtype=int_))
                continue
            obj_sim_start = np.empty(len(obj._sim_start), dtype=int_)
            for i in range(len(obj._sim_start)):
                if obj._sim_start[i] == 0:
                    obj_sim_start[i] = 0
                elif obj._sim_start[i] == len(obj.wrapper.index):
                    obj_sim_start[i] = len(new_wrapper.index)
                else:
                    _obj_sim_start = new_wrapper.index.get_indexer([obj.wrapper.index[obj._sim_start[i]]])[0]
                    if _obj_sim_start == -1:
                        _obj_sim_start = 0
                    obj_sim_start[i] = _obj_sim_start
            obj_sim_starts.append(obj_sim_start)
        return new_wrapper.concat_arrs(*obj_sim_starts, wrap=False)

    @classmethod
    def column_stack_sim_end(
        cls,
        new_wrapper: ArrayWrapper,
        *objs: tp.MaybeTuple[SimRangeMixinT],
    ) -> tp.Optional[tp.ArrayLike]:
        """Column-stack simulation end.

        Each end position is translated from the object's own index into `new_wrapper.index`.
        Position 0 stays 0, a position equal to the object's index length becomes the length
        of the new index, and any other position is looked up by label (0 if the label is missing).

        Objects without a simulation end contribute the length of the new index, such that
        objects with and without an end can be stacked together.

        Args:
            new_wrapper (ArrayWrapper): Wrapper of the stacked object.
            *objs: Objects to stack, passed either separately or as a single sequence.

        Returns:
            Concatenated end positions, or None if no object has a simulation end.
        """
        if len(objs) == 1:
            objs = objs[0]
        objs = list(objs)

        if all(obj._sim_end is None for obj in objs):
            return None

        obj_sim_ends = []
        for obj in objs:
            if obj._sim_end is None:
                # No explicit end means "past the last row", which maps to the new index length
                obj_sim_ends.append(np.full(obj.wrapper.shape_2d[1], len(new_wrapper.index), dtype=int_))
                continue
            obj_sim_end = np.empty(len(obj._sim_end), dtype=int_)
            for i in range(len(obj._sim_end)):
                if obj._sim_end[i] == 0:
                    obj_sim_end[i] = 0
                elif obj._sim_end[i] == len(obj.wrapper.index):
                    obj_sim_end[i] = len(new_wrapper.index)
                else:
                    _obj_sim_end = new_wrapper.index.get_indexer([obj.wrapper.index[obj._sim_end[i]]])[0]
                    if _obj_sim_end == -1:
                        _obj_sim_end = 0
                    obj_sim_end[i] = _obj_sim_end
            obj_sim_ends.append(obj_sim_end)
        return new_wrapper.concat_arrs(*obj_sim_ends, wrap=False)

    def __init__(
        self,
        sim_start: tp.Optional[tp.Array1d] = None,
        sim_end: tp.Optional[tp.Array1d] = None,
    ) -> None:
        # Resolved against the ungrouped wrapper (group_by=False); None is kept as None
        sim_start = type(self).resolve_sim_start(sim_start=sim_start, wrapper=self.wrapper, group_by=False)
        sim_end = type(self).resolve_sim_end(sim_end=sim_end, wrapper=self.wrapper, group_by=False)

        self._sim_start = sim_start
        self._sim_end = sim_end

    def sim_start_indexing_func(self, wrapper_meta: dict) -> tp.Optional[tp.ArrayLike]:
        """Indexing function for simulation start.

        Args:
            wrapper_meta (dict): Indexing metadata. The keys "rows_changed", "row_idxs",
                and "new_wrapper" are read.

        Returns:
            None if there is no simulation start, the stored start if rows have not changed,
            otherwise the start shifted by the first selected row and clipped to the new index.
        """
        if self._sim_start is None:
            return None
        if not wrapper_meta["rows_changed"]:
            return self._sim_start

        row_idxs = wrapper_meta["row_idxs"]
        if checks.is_int(row_idxs):
            offset = row_idxs
        elif isinstance(row_idxs, slice):
            # An open-ended slice begins at the first row
            offset = 0 if row_idxs.start is None else row_idxs.start
        else:
            offset = row_idxs[0]
        return np.clip(self._sim_start - offset, 0, len(wrapper_meta["new_wrapper"].index))

    def sim_end_indexing_func(self, wrapper_meta: dict) -> tp.Optional[tp.ArrayLike]:
        """Indexing function for simulation end.

        Args:
            wrapper_meta (dict): Indexing metadata. The keys "rows_changed", "row_idxs",
                and "new_wrapper" are read.

        Returns:
            None if there is no simulation end, the stored end if rows have not changed,
            otherwise the end shifted by the first selected row and clipped to the new index.
        """
        if self._sim_end is None:
            return None
        if not wrapper_meta["rows_changed"]:
            return self._sim_end

        row_idxs = wrapper_meta["row_idxs"]
        if checks.is_int(row_idxs):
            offset = row_idxs
        elif isinstance(row_idxs, slice):
            # An open-ended slice begins at the first row
            offset = 0 if row_idxs.start is None else row_idxs.start
        else:
            offset = row_idxs[0]
        return np.clip(self._sim_end - offset, 0, len(wrapper_meta["new_wrapper"].index))

    def resample_sim_start(self, new_wrapper: ArrayWrapper) -> tp.Optional[tp.ArrayLike]:
        """Resample simulation start.

        Position 0 stays 0 and a position equal to the current index length becomes the length
        of the new index. Any other position is looked up by label in `new_wrapper.index`
        using forward fill; if nothing is found, 0 is used.

        Args:
            new_wrapper (ArrayWrapper): Wrapper with the resampled index.

        Returns:
            Start positions in the new index, or None if there is no simulation start.
        """
        if self._sim_start is not None:
            new_sim_start = np.empty(len(self._sim_start), dtype=int_)
            for i in range(len(self._sim_start)):
                if self._sim_start[i] == 0:
                    new_sim_start[i] = 0
                elif self._sim_start[i] == len(self.wrapper.index):
                    new_sim_start[i] = len(new_wrapper.index)
                else:
                    _new_sim_start = new_wrapper.index.get_indexer(
                        [self.wrapper.index[self._sim_start[i]]],
                        method="ffill",
                    )[0]
                    if _new_sim_start == -1:
                        _new_sim_start = 0
                    new_sim_start[i] = _new_sim_start
        else:
            new_sim_start = None
        return new_sim_start

    def resample_sim_end(self, new_wrapper: ArrayWrapper) -> tp.Optional[tp.ArrayLike]:
        """Resample simulation end.

        Position 0 stays 0 and a position equal to the current index length becomes the length
        of the new index. Any other position is looked up by label in `new_wrapper.index`
        using backward fill; if nothing is found, the length of the new index is used.

        Args:
            new_wrapper (ArrayWrapper): Wrapper with the resampled index.

        Returns:
            End positions in the new index, or None if there is no simulation end.
        """
        if self._sim_end is not None:
            new_sim_end = np.empty(len(self._sim_end), dtype=int_)
            for i in range(len(self._sim_end)):
                if self._sim_end[i] == 0:
                    new_sim_end[i] = 0
                elif self._sim_end[i] == len(self.wrapper.index):
                    new_sim_end[i] = len(new_wrapper.index)
                else:
                    _new_sim_end = new_wrapper.index.get_indexer(
                        [self.wrapper.index[self._sim_end[i]]],
                        method="bfill",
                    )[0]
                    if _new_sim_end == -1:
                        _new_sim_end = len(new_wrapper.index)
                    new_sim_end[i] = _new_sim_end
        else:
            new_sim_end = None
        return new_sim_end

    @hybrid_method
    def resolve_sim_start_value(
        cls_or_self,
        value: tp.Scalar,
        wrapper: tp.Optional[ArrayWrapper] = None,
    ) -> int:
        """Resolve a single value of simulation start.

        The value is passed to `vectorbtpro.base.indexing.AutoIdxr` created with
        `indexer_method="bfill"` and `below_to_zero=True`, and evaluated against
        the index and frequency of the wrapper.

        Args:
            value (scalar): Value to resolve.
            wrapper (ArrayWrapper): Wrapper to resolve against. Required when called on the class;
                defaults to the wrapper of the instance otherwise.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        auto_idxr = AutoIdxr(value, indexer_method="bfill", below_to_zero=True)
        return auto_idxr.get(wrapper.index, freq=wrapper.freq)

    @hybrid_method
    def resolve_sim_end_value(
        cls_or_self,
        value: tp.Scalar,
        wrapper: tp.Optional[ArrayWrapper] = None,
    ) -> int:
        """Resolve a single value of simulation end.

        The value is passed to `vectorbtpro.base.indexing.AutoIdxr` created with
        `indexer_method="bfill"` and `above_to_len=True`, and evaluated against
        the index and frequency of the wrapper.

        Args:
            value (scalar): Value to resolve.
            wrapper (ArrayWrapper): Wrapper to resolve against. Required when called on the class;
                defaults to the wrapper of the instance otherwise.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        auto_idxr = AutoIdxr(value, indexer_method="bfill", above_to_len=True)
        return auto_idxr.get(wrapper.index, freq=wrapper.freq)

    @hybrid_method
    def resolve_sim_start(
        cls_or_self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        allow_none: bool = True,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
    ) -> tp.Optional[tp.ArrayLike]:
        """Resolve simulation start.

        Args:
            sim_start (array_like): Simulation start. If None and called on an instance,
                the stored (already resolved) start is used. False is treated as None.
                Non-integer values are resolved one by one with
                `SimRangeMixin.resolve_sim_start_value`.
            allow_none (bool): Whether to return None if there is no simulation start.
            wrapper (ArrayWrapper): Wrapper to resolve against. Required when called on the class;
                defaults to the wrapper of the instance otherwise.
            group_by (any): Grouping instruction passed to the grouper of the wrapper.

        Returns:
            None if `allow_none` is True and there is no start, otherwise the output of
            the matching `resolve_*_sim_start_nb` function.
        """
        already_resolved = False
        if not isinstance(cls_or_self, type):
            if sim_start is None:
                # The stored value went through this function in the constructor
                sim_start = cls_or_self._sim_start
                already_resolved = True
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        if sim_start is False:
            sim_start = None
        if allow_none and sim_start is None:
            return None
        if not already_resolved and sim_start is not None:
            sim_start_arr = np.asarray(sim_start)
            if not np.issubdtype(sim_start_arr.dtype, np.integer):
                if sim_start_arr.ndim == 0:
                    sim_start = cls_or_self.resolve_sim_start_value(sim_start, wrapper=wrapper)
                else:
                    new_sim_start = np.empty(len(sim_start), dtype=int_)
                    for i in range(len(sim_start)):
                        new_sim_start[i] = cls_or_self.resolve_sim_start_value(sim_start[i], wrapper=wrapper)
                    sim_start = new_sim_start
        if wrapper.grouper.is_grouped(group_by=group_by):
            group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
            sim_start = nb.resolve_grouped_sim_start_nb(
                wrapper.shape_2d,
                group_lens,
                sim_start=sim_start,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        elif not already_resolved and wrapper.grouper.is_grouped():
            group_lens = wrapper.grouper.get_group_lens()
            sim_start = nb.resolve_ungrouped_sim_start_nb(
                wrapper.shape_2d,
                group_lens,
                sim_start=sim_start,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        else:
            sim_start = nb.resolve_sim_start_nb(
                wrapper.shape_2d,
                sim_start=sim_start,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        return sim_start

    @hybrid_method
    def resolve_sim_end(
        cls_or_self,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        allow_none: bool = True,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
    ) -> tp.Optional[tp.ArrayLike]:
        """Resolve simulation end.

        Args:
            sim_end (array_like): Simulation end. If None and called on an instance,
                the stored (already resolved) end is used. False is treated as None.
                Non-integer values are resolved one by one with
                `SimRangeMixin.resolve_sim_end_value`.
            allow_none (bool): Whether to return None if there is no simulation end.
            wrapper (ArrayWrapper): Wrapper to resolve against. Required when called on the class;
                defaults to the wrapper of the instance otherwise.
            group_by (any): Grouping instruction passed to the grouper of the wrapper.

        Returns:
            None if `allow_none` is True and there is no end, otherwise the output of
            the matching `resolve_*_sim_end_nb` function.
        """
        already_resolved = False
        if not isinstance(cls_or_self, type):
            if sim_end is None:
                # The stored value went through this function in the constructor
                sim_end = cls_or_self._sim_end
                already_resolved = True
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        if sim_end is False:
            sim_end = None
        if allow_none and sim_end is None:
            return None
        if not already_resolved and sim_end is not None:
            sim_end_arr = np.asarray(sim_end)
            if not np.issubdtype(sim_end_arr.dtype, np.integer):
                if sim_end_arr.ndim == 0:
                    sim_end = cls_or_self.resolve_sim_end_value(sim_end, wrapper=wrapper)
                else:
                    new_sim_end = np.empty(len(sim_end), dtype=int_)
                    for i in range(len(sim_end)):
                        new_sim_end[i] = cls_or_self.resolve_sim_end_value(sim_end[i], wrapper=wrapper)
                    sim_end = new_sim_end
        if wrapper.grouper.is_grouped(group_by=group_by):
            group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
            sim_end = nb.resolve_grouped_sim_end_nb(
                wrapper.shape_2d,
                group_lens,
                sim_end=sim_end,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        elif not already_resolved and wrapper.grouper.is_grouped():
            group_lens = wrapper.grouper.get_group_lens()
            sim_end = nb.resolve_ungrouped_sim_end_nb(
                wrapper.shape_2d,
                group_lens,
                sim_end=sim_end,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        else:
            sim_end = nb.resolve_sim_end_nb(
                wrapper.shape_2d,
                sim_end=sim_end,
                allow_none=allow_none,
                check_bounds=not already_resolved,
            )
        return sim_end

    @hybrid_method
    def get_sim_start(
        cls_or_self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        keep_flex: bool = False,
        allow_none: bool = False,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.Union[None, tp.Array1d, tp.Series]:
        """Get simulation start.

        Args:
            sim_start (array_like): See `SimRangeMixin.resolve_sim_start`.
            keep_flex (bool): Whether to return the resolved array without wrapping it.
            allow_none (bool): See `SimRangeMixin.resolve_sim_start`.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_start`.
            group_by (any): See `SimRangeMixin.resolve_sim_start`.
            wrap_kwargs (dict): Keyword arguments passed to `wrapper.wrap_reduced`.
                Merged over `name_or_index="sim_start"`.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_start = cls_or_self.resolve_sim_start(
            sim_start=sim_start,
            allow_none=allow_none,
            wrapper=wrapper,
            group_by=group_by,
        )
        if sim_start is None:
            return None
        if keep_flex:
            return sim_start
        wrap_kwargs = merge_dicts(dict(name_or_index="sim_start"), wrap_kwargs)
        return wrapper.wrap_reduced(sim_start, group_by=group_by, **wrap_kwargs)

    @property
    def sim_start(self) -> tp.Series:
        """`SimRangeMixin.get_sim_start` with default arguments."""
        return self.get_sim_start()

    @hybrid_method
    def get_sim_end(
        cls_or_self,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        keep_flex: bool = False,
        allow_none: bool = False,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.Union[None, tp.Array1d, tp.Series]:
        """Get simulation end.

        Args:
            sim_end (array_like): See `SimRangeMixin.resolve_sim_end`.
            keep_flex (bool): Whether to return the resolved array without wrapping it.
            allow_none (bool): See `SimRangeMixin.resolve_sim_end`.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_end`.
            group_by (any): See `SimRangeMixin.resolve_sim_end`.
            wrap_kwargs (dict): Keyword arguments passed to `wrapper.wrap_reduced`.
                Merged over `name_or_index="sim_end"`.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_end = cls_or_self.resolve_sim_end(
            sim_end=sim_end,
            allow_none=allow_none,
            wrapper=wrapper,
            group_by=group_by,
        )
        if sim_end is None:
            return None
        if keep_flex:
            return sim_end
        wrap_kwargs = merge_dicts(dict(name_or_index="sim_end"), wrap_kwargs)
        return wrapper.wrap_reduced(sim_end, group_by=group_by, **wrap_kwargs)

    @property
    def sim_end(self) -> tp.Series:
        """`SimRangeMixin.get_sim_end` with default arguments."""
        return self.get_sim_end()

    @hybrid_method
    def get_sim_start_index(
        cls_or_self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        allow_none: bool = False,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.Optional[tp.Series]:
        """Get index of simulation start.

        A start equal to the length of the index lies past the last row: it is extrapolated
        by one frequency step for a datetime index with a known frequency, by one for
        a range index, and is None otherwise.

        Args:
            sim_start (array_like): See `SimRangeMixin.resolve_sim_start`.
            allow_none (bool): See `SimRangeMixin.resolve_sim_start`.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_start`.
            group_by (any): See `SimRangeMixin.resolve_sim_start`.
            wrap_kwargs (dict): Keyword arguments passed to `wrapper.wrap_reduced`.
                Merged over `name_or_index="sim_start_index"`.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_start = cls_or_self.resolve_sim_start(
            sim_start=sim_start,
            allow_none=allow_none,
            wrapper=wrapper,
            group_by=group_by,
        )
        if sim_start is None:
            return None
        start_index = []
        for i in range(len(sim_start)):
            _sim_start = sim_start[i]
            if _sim_start == 0:
                start_index.append(wrapper.index[0])
            elif _sim_start == len(wrapper.index):
                if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
                    start_index.append(wrapper.index[-1] + wrapper.freq)
                elif isinstance(wrapper.index, pd.RangeIndex):
                    start_index.append(wrapper.index[-1] + 1)
                else:
                    start_index.append(None)
            else:
                start_index.append(wrapper.index[_sim_start])
        wrap_kwargs = merge_dicts(dict(name_or_index="sim_start_index"), wrap_kwargs)
        return wrapper.wrap_reduced(pd.Index(start_index), group_by=group_by, **wrap_kwargs)

    @property
    def sim_start_index(self) -> tp.Series:
        """`SimRangeMixin.get_sim_start_index` with default arguments."""
        return self.get_sim_start_index()

    @hybrid_method
    def get_sim_end_index(
        cls_or_self,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        allow_none: bool = False,
        inclusive: bool = True,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.Optional[tp.Series]:
        """Get index of simulation end.

        Args:
            sim_end (array_like): See `SimRangeMixin.resolve_sim_end`.
            allow_none (bool): See `SimRangeMixin.resolve_sim_end`.
            inclusive (bool): Whether to return the label of the last row inside the range
                (`sim_end - 1`) rather than the label at `sim_end`.

                If True, an end of 0 yields None. If False, an end equal to the length of
                the index is extrapolated by one frequency step for a datetime index with
                a known frequency, by one for a range index, and is None otherwise.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_end`.
            group_by (any): See `SimRangeMixin.resolve_sim_end`.
            wrap_kwargs (dict): Keyword arguments passed to `wrapper.wrap_reduced`.
                Merged over `name_or_index="sim_end_index"`.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_end = cls_or_self.resolve_sim_end(
            sim_end=sim_end,
            allow_none=allow_none,
            wrapper=wrapper,
            group_by=group_by,
        )
        if sim_end is None:
            return None
        end_index = []
        for i in range(len(sim_end)):
            _sim_end = sim_end[i]
            if _sim_end == 0:
                if inclusive:
                    end_index.append(None)
                else:
                    end_index.append(wrapper.index[0])
            elif _sim_end == len(wrapper.index):
                if inclusive:
                    end_index.append(wrapper.index[-1])
                else:
                    if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
                        end_index.append(wrapper.index[-1] + wrapper.freq)
                    elif isinstance(wrapper.index, pd.RangeIndex):
                        end_index.append(wrapper.index[-1] + 1)
                    else:
                        end_index.append(None)
            else:
                if inclusive:
                    end_index.append(wrapper.index[_sim_end - 1])
                else:
                    end_index.append(wrapper.index[_sim_end])
        wrap_kwargs = merge_dicts(dict(name_or_index="sim_end_index"), wrap_kwargs)
        return wrapper.wrap_reduced(pd.Index(end_index), group_by=group_by, **wrap_kwargs)

    @property
    def sim_end_index(self) -> tp.Series:
        """`SimRangeMixin.get_sim_end_index` with default arguments."""
        return self.get_sim_end_index()

    @hybrid_method
    def get_sim_duration(
        cls_or_self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.Optional[tp.Series]:
        """Get duration of simulation range.

        Computed as the resolved end minus the resolved start, both resolved with `allow_none=False`.

        Args:
            sim_start (array_like): See `SimRangeMixin.resolve_sim_start`.
            sim_end (array_like): See `SimRangeMixin.resolve_sim_end`.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_start`.
            group_by (any): See `SimRangeMixin.resolve_sim_start`.
            wrap_kwargs (dict): Keyword arguments passed to `wrapper.wrap_reduced`.
                Merged over `name_or_index="sim_duration"`.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_start = cls_or_self.resolve_sim_start(
            sim_start=sim_start,
            allow_none=False,
            wrapper=wrapper,
            group_by=group_by,
        )
        sim_end = cls_or_self.resolve_sim_end(
            sim_end=sim_end,
            allow_none=False,
            wrapper=wrapper,
            group_by=group_by,
        )
        total_duration = sim_end - sim_start
        wrap_kwargs = merge_dicts(dict(name_or_index="sim_duration"), wrap_kwargs)
        return wrapper.wrap_reduced(total_duration, group_by=group_by, **wrap_kwargs)

    @property
    def sim_duration(self) -> tp.Series:
        """`SimRangeMixin.get_sim_duration` with default arguments."""
        return self.get_sim_duration()

    @hybrid_method
    def fit_fig_to_sim_range(
        cls_or_self,
        fig: tp.BaseFigure,
        column: tp.Optional[tp.Label] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
        xref: tp.Optional[str] = None,
    ) -> tp.BaseFigure:
        """Fit figure to simulation range.

        The figure is returned untouched if there is neither a start nor an end, or if
        the range is empty (start at the length of the index, end at 0, or start equal to end).

        Args:
            fig (Figure or FigureWidget): Figure to update.
            column (str): Name of the column/group to take the range from.
            sim_start (array_like): See `SimRangeMixin.resolve_sim_start`.
            sim_end (array_like): See `SimRangeMixin.resolve_sim_end`.
            wrapper (ArrayWrapper): See `SimRangeMixin.resolve_sim_start`.
            group_by (any): See `SimRangeMixin.resolve_sim_start`.
            xref (str): X coordinate axis, such as "x" or "x2".
                If None, the range of all x-axes is updated.
        """
        if not isinstance(cls_or_self, type):
            if wrapper is None:
                wrapper = cls_or_self.wrapper
        else:
            checks.assert_not_none(wrapper, arg_name="wrapper")
        sim_start = cls_or_self.get_sim_start(
            sim_start=sim_start,
            allow_none=True,
            wrapper=wrapper,
            group_by=group_by,
        )
        sim_end = cls_or_self.get_sim_end(
            sim_end=sim_end,
            allow_none=True,
            wrapper=wrapper,
            group_by=group_by,
        )
        if sim_start is not None:
            sim_start = wrapper.select_col_from_obj(sim_start, column=column, group_by=group_by)
        if sim_end is not None:
            sim_end = wrapper.select_col_from_obj(sim_end, column=column, group_by=group_by)
        if sim_start is not None or sim_end is not None:
            if sim_start == len(wrapper.index) or sim_end == 0 or sim_start == sim_end:
                return fig
            if sim_start is None:
                sim_start = 0
            # Left edge is one step before the first row of the range
            if sim_start > 0:
                sim_start_index = wrapper.index[sim_start - 1]
            else:
                if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
                    sim_start_index = wrapper.index[0] - wrapper.freq
                elif isinstance(wrapper.index, pd.RangeIndex):
                    sim_start_index = wrapper.index[0] - 1
                else:
                    sim_start_index = wrapper.index[0]
            if sim_end is None:
                sim_end = len(wrapper.index)
            # Right edge is the label at the (exclusive) end position
            if sim_end < len(wrapper.index):
                sim_end_index = wrapper.index[sim_end]
            else:
                if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
                    sim_end_index = wrapper.index[-1] + wrapper.freq
                elif isinstance(wrapper.index, pd.RangeIndex):
                    sim_end_index = wrapper.index[-1] + 1
                else:
                    sim_end_index = wrapper.index[-1]
            if xref is not None:
                # "x" -> "xaxis", "x2" -> "xaxis2"
                xaxis = "xaxis" + xref[1:]
                fig.update_layout(**{xaxis: dict(range=[sim_start_index, sim_end_index])})
            else:
                fig.update_xaxes(range=[sim_start_index, sim_end_index])
        return fig
def resolve_stats_setting(
    self,
    value: tp.Optional[tp.Any],
    key: str,
    merge: bool = False,
) -> tp.Any:
    """Resolve a setting for `StatsBuilderMixin.stats`.

    Args:
        value (any): Value provided by the user.
        key (str): Name of the setting in `stats_builder` of `vectorbtpro._settings`
            and in `StatsBuilderMixin.stats_defaults`.
        merge (bool): Whether to merge the global setting, the instance default, and
            `value` (in this order) with `vectorbtpro.utils.config.merge_dicts`.

            If False, returns `value` if it is not None, otherwise the instance default,
            otherwise the global setting.
    """
    from vectorbtpro._settings import settings as _settings

    global_cfg = _settings["stats_builder"]

    if not merge:
        if value is not None:
            return value
        instance_defaults = self.stats_defaults
        global_default = global_cfg[key]
        return instance_defaults.get(key, global_default)

    layers = (
        global_cfg[key],
        self.stats_defaults.get(key, {}),
        value,
    )
    return merge_dicts(*layers)
To change metrics, you can either change the config in-place, override this property, or overwrite the instance variable `${cls_name}._metrics`.""" return self._metrics def stats( self, metrics: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[str, tp.Kwargs]]]] = None, tags: tp.Optional[tp.MaybeIterable[str]] = None, column: tp.Optional[tp.Label] = None, group_by: tp.GroupByLike = None, per_column: tp.Optional[bool] = None, split_columns: tp.Optional[bool] = None, agg_func: tp.Optional[tp.Callable] = np.mean, dropna: tp.Optional[bool] = None, silence_warnings: tp.Optional[bool] = None, template_context: tp.KwargsLike = None, settings: tp.KwargsLike = None, filters: tp.KwargsLike = None, metric_settings: tp.KwargsLike = None, ) -> tp.Optional[tp.SeriesFrame]: """Compute various metrics on this object. Args: metrics (str, tuple, iterable, or dict): Metrics to calculate. Each element can be either: * Metric name (see keys in `StatsBuilderMixin.metrics`) * Tuple of a metric name and a settings dict as in `StatsBuilderMixin.metrics` * Tuple of a metric name and a template of instance `vectorbtpro.utils.template.CustomTemplate` * Tuple of a metric name and a list of settings dicts to be expanded into multiple metrics The settings dict can contain the following keys: * `title`: Title of the metric. Defaults to the name. * `tags`: Single or multiple tags to associate this metric with. If any of these tags is in `tags`, keeps this metric. * `check_{filter}` and `inv_check_{filter}`: Whether to check this metric against a filter defined in `filters`. True (or False for inverse) means to keep this metric. * `calc_func` (required): Calculation function for custom metrics. Must return either a scalar for one column/group, pd.Series for multiple columns/groups, or a dict of such for multiple sub-metrics. * `resolve_calc_func`: whether to resolve `calc_func`. 
If the function can be accessed by traversing attributes of this object, you can specify the path to this function as a string (see `vectorbtpro.utils.attr_.deep_getattr` for the path format). If `calc_func` is a function, arguments from merged metric settings are matched with arguments in the signature (see below). If `resolve_calc_func` is False, `calc_func` must accept (resolved) self and dictionary of merged metric settings. Defaults to True. * `use_shortcuts`: Whether to use shortcut properties whenever possible when resolving `calc_func`. Defaults to True. * `post_calc_func`: Function to post-process the result of `calc_func`. Must accept (resolved) self, output of `calc_func`, and dictionary of merged metric settings, and return whatever is acceptable to be returned by `calc_func`. Defaults to None. * `fill_wrap_kwargs`: Whether to fill `wrap_kwargs` with `to_timedelta` and `silence_warnings`. Defaults to False. * `apply_to_timedelta`: Whether to apply `vectorbtpro.base.wrapping.ArrayWrapper.arr_to_timedelta` on the result. To disable this globally, pass `to_timedelta=False` in `settings`. Defaults to False. * `pass_{arg}`: Whether to pass any argument from the settings (see below). Defaults to True if this argument was found in the function's signature. Set to False to not pass. If argument to be passed was not found, `pass_{arg}` is removed. * `resolve_path_{arg}`: Whether to resolve an argument that is meant to be an attribute of this object and is the first part of the path of `calc_func`. Passes only optional arguments. Defaults to True. See `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`. * `resolve_{arg}`: Whether to resolve an argument that is meant to be an attribute of this object and is present in the function's signature. Defaults to False. See `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`. * `use_shortcuts_{arg}`: Whether to use shortcut properties whenever possible when resolving an argument. Defaults to True. 
* `select_col_{arg}`: Whether to select the column from an argument that is meant to be an attribute of this object. Defaults to False. * `template_context`: Mapping to replace templates in metric settings. Used across all settings. * Any other keyword argument that overrides the settings or is passed directly to `calc_func`. If `resolve_calc_func` is True, the calculation function may "request" any of the following arguments by accepting them or if `pass_{arg}` was found in the settings dict: * Each of `vectorbtpro.utils.attr_.AttrResolverMixin.self_aliases`: original object (ungrouped, with no column selected) * `group_by`: won't be passed if it was used in resolving the first attribute of `calc_func` specified as a path, use `pass_group_by=True` to pass anyway * `column` * `metric_name` * `agg_func` * `silence_warnings` * `to_timedelta`: replaced by True if None and frequency is set * Any argument from `settings` * Any attribute of this object if it meant to be resolved (see `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`) Pass `metrics='all'` to calculate all supported metrics. tags (str or iterable): Tags to select. See `vectorbtpro.utils.tagging.match_tags`. column (str): Name of the column/group. !!! hint There are two ways to select a column: `obj['a'].stats()` and `obj.stats(column='a')`. They both accomplish the same thing but in different ways: `obj['a'].stats()` computes statistics of the column 'a' only, while `obj.stats(column='a')` computes statistics of all columns first and only then selects the column 'a'. The first method is preferred when you have a lot of data or caching is disabled. The second method is preferred when most attributes have already been cached. group_by (any): Group or ungroup columns. See `vectorbtpro.base.grouping.base.Grouper`. per_column (bool): Whether to compute per column and then stack along columns. split_columns (bool): Whether to split this instance into multiple columns when `per_column` is True. 
Otherwise, iterates over columns and passes `column` to the whole instance. agg_func (callable): Aggregation function to aggregate statistics across all columns. By default, takes the mean of all columns. If None, returns all columns as a DataFrame. Must take `pd.Series` and return a const. Takes effect if `column` was specified or this object contains only one column of data. If `agg_func` has been overridden by a metric: * Takes effect if global `agg_func` is not None * Raises a warning if it's None but the result of calculation has multiple values dropna (bool): Whether to hide metrics that are all NaN. silence_warnings (bool): Whether to silence all warnings. template_context (mapping): Context used to substitute templates. Gets merged over `template_context` from `vectorbtpro._settings.stats_builder` and `StatsBuilderMixin.stats_defaults`. Applied on `settings` and then on each metric settings. filters (dict): Filters to apply. Each item consists of the filter name and settings dict. The settings dict can contain the following keys: * `filter_func`: Filter function that must accept resolved self and merged settings for a metric, and return either True or False. * `warning_message`: Warning message to be shown when skipping a metric. Can be a template that will be substituted using merged metric settings as context. Defaults to None. * `inv_warning_message`: Same as `warning_message` but for inverse checks. Gets merged over `filters` from `vectorbtpro._settings.stats_builder` and `StatsBuilderMixin.stats_defaults`. settings (dict): Global settings and resolution arguments. Extends/overrides `settings` from `vectorbtpro._settings.stats_builder` and `StatsBuilderMixin.stats_defaults`. Gets extended/overridden by metric settings. metric_settings (dict): Keyword arguments for each metric. Extends/overrides all global and metric settings. For template logic, see `vectorbtpro.utils.template`. 
For defaults, see `vectorbtpro._settings.stats_builder` and `StatsBuilderMixin.stats_defaults`. !!! hint There are two types of arguments: optional (or resolution) and mandatory arguments. Optional arguments are only passed if they are found in the function's signature. Mandatory arguments are passed regardless of this. Optional arguments can only be defined using `settings` (that is, globally), while mandatory arguments can be defined both using default metric settings and `{metric_name}_kwargs`. Overriding optional arguments using default metric settings or `{metric_name}_kwargs` won't turn them into mandatory. For this, pass `pass_{arg}=True`. !!! hint Make sure to resolve and then to re-use as many object attributes as possible to utilize built-in caching (even if global caching is disabled). """ # Compute per column if column is None: if per_column is None: per_column = self.resolve_stats_setting(per_column, "per_column") if per_column: columns = self.get_item_keys(group_by=group_by) if len(columns) > 1: results = [] if split_columns: for _, column_self in self.items(group_by=group_by, wrap=True): _args, _kwargs = get_forward_args(column_self.stats, locals()) results.append(column_self.stats(*_args, **_kwargs)) else: for column in columns: _args, _kwargs = get_forward_args(self.stats, locals()) results.append(self.stats(*_args, **_kwargs)) return pd.concat(results, keys=columns, axis=1) # Resolve defaults dropna = self.resolve_stats_setting(dropna, "dropna") silence_warnings = self.resolve_stats_setting(silence_warnings, "silence_warnings") template_context = self.resolve_stats_setting(template_context, "template_context", merge=True) filters = self.resolve_stats_setting(filters, "filters", merge=True) settings = self.resolve_stats_setting(settings, "settings", merge=True) metric_settings = self.resolve_stats_setting(metric_settings, "metric_settings", merge=True) # Replace templates globally (not used at metric level) if len(template_context) > 0: 
sub_settings = substitute_templates( settings, context=template_context, eval_id="sub_settings", strict=False, ) else: sub_settings = settings # Resolve self reself = self.resolve_self( cond_kwargs=sub_settings, impacts_caching=False, silence_warnings=silence_warnings, ) # Prepare metrics metrics = reself.resolve_stats_setting(metrics, "metrics") if metrics == "all": metrics = reself.metrics if isinstance(metrics, dict): metrics = list(metrics.items()) if isinstance(metrics, (str, tuple)): metrics = [metrics] # Prepare tags tags = reself.resolve_stats_setting(tags, "tags") if isinstance(tags, str) and tags == "all": tags = None if isinstance(tags, (str, tuple)): tags = [tags] # Bring to the same shape new_metrics = [] for i, metric in enumerate(metrics): if isinstance(metric, str): metric = (metric, reself.metrics[metric]) if not isinstance(metric, tuple): raise TypeError(f"Metric at index {i} must be either a string or a tuple") new_metrics.append(metric) metrics = new_metrics # Expand metrics new_metrics = [] for i, (metric_name, _metric_settings) in enumerate(metrics): if isinstance(_metric_settings, CustomTemplate): metric_context = merge_dicts( template_context, {name: reself for name in reself.self_aliases}, dict( column=column, group_by=group_by, metric_name=metric_name, agg_func=agg_func, silence_warnings=silence_warnings, to_timedelta=None, ), settings, ) metric_context = substitute_templates( metric_context, context=metric_context, eval_id="metric_context", ) _metric_settings = _metric_settings.substitute( context=metric_context, strict=True, eval_id="metric", ) if isinstance(_metric_settings, list): for __metric_settings in _metric_settings: new_metrics.append((metric_name, __metric_settings)) else: new_metrics.append((metric_name, _metric_settings)) metrics = new_metrics # Handle duplicate names metric_counts = Counter(list(map(lambda x: x[0], metrics))) metric_i = {k: -1 for k in metric_counts.keys()} metrics_dct = {} for i, (metric_name, 
_metric_settings) in enumerate(metrics): if metric_counts[metric_name] > 1: metric_i[metric_name] += 1 metric_name = metric_name + "_" + str(metric_i[metric_name]) metrics_dct[metric_name] = _metric_settings # Check metric_settings missed_keys = set(metric_settings.keys()).difference(set(metrics_dct.keys())) if len(missed_keys) > 0: raise ValueError(f"Keys {missed_keys} in metric_settings could not be matched with any metric") # Merge settings opt_arg_names_dct = {} custom_arg_names_dct = {} resolved_self_dct = {} context_dct = {} for metric_name, _metric_settings in list(metrics_dct.items()): opt_settings = merge_dicts( {name: reself for name in reself.self_aliases}, dict( column=column, group_by=group_by, metric_name=metric_name, agg_func=agg_func, silence_warnings=silence_warnings, to_timedelta=None, ), settings, ) _metric_settings = _metric_settings.copy() passed_metric_settings = metric_settings.get(metric_name, {}) merged_settings = merge_dicts(opt_settings, _metric_settings, passed_metric_settings) metric_template_context = merged_settings.pop("template_context", {}) template_context_merged = merge_dicts(template_context, metric_template_context) template_context_merged = substitute_templates( template_context_merged, context=merged_settings, eval_id="template_context_merged", ) context = merge_dicts(template_context_merged, merged_settings) merged_settings = substitute_templates( merged_settings, context=context, eval_id="merged_settings", ) # Filter by tag if tags is not None: in_tags = merged_settings.get("tags", None) if in_tags is None or not match_tags(tags, in_tags): metrics_dct.pop(metric_name, None) continue custom_arg_names = set(_metric_settings.keys()).union(set(passed_metric_settings.keys())) opt_arg_names = set(opt_settings.keys()) custom_reself = reself.resolve_self( cond_kwargs=merged_settings, custom_arg_names=custom_arg_names, impacts_caching=True, silence_warnings=merged_settings["silence_warnings"], ) metrics_dct[metric_name] = 
merged_settings custom_arg_names_dct[metric_name] = custom_arg_names opt_arg_names_dct[metric_name] = opt_arg_names resolved_self_dct[metric_name] = custom_reself context_dct[metric_name] = context # Filter metrics for metric_name, _metric_settings in list(metrics_dct.items()): custom_reself = resolved_self_dct[metric_name] context = context_dct[metric_name] _silence_warnings = _metric_settings.get("silence_warnings") metric_filters = set() for k in _metric_settings.keys(): filter_name = None if k.startswith("check_"): filter_name = k[len("check_") :] elif k.startswith("inv_check_"): filter_name = k[len("inv_check_") :] if filter_name is not None: if filter_name not in filters: raise ValueError(f"Metric '{metric_name}' requires filter '{filter_name}'") metric_filters.add(filter_name) for filter_name in metric_filters: filter_settings = filters[filter_name] _filter_settings = substitute_templates( filter_settings, context=context, eval_id="filter_settings", ) filter_func = _filter_settings["filter_func"] warning_message = _filter_settings.get("warning_message", None) inv_warning_message = _filter_settings.get("inv_warning_message", None) to_check = _metric_settings.get("check_" + filter_name, False) inv_to_check = _metric_settings.get("inv_check_" + filter_name, False) if to_check or inv_to_check: whether_true = filter_func(custom_reself, _metric_settings) to_remove = (to_check and not whether_true) or (inv_to_check and whether_true) if to_remove: if to_check and warning_message is not None and not _silence_warnings: warn(warning_message) if inv_to_check and inv_warning_message is not None and not _silence_warnings: warn(inv_warning_message) metrics_dct.pop(metric_name, None) custom_arg_names_dct.pop(metric_name, None) opt_arg_names_dct.pop(metric_name, None) resolved_self_dct.pop(metric_name, None) context_dct.pop(metric_name, None) break # Any metrics left? 
if len(metrics_dct) == 0: if not silence_warnings: warn("No metrics to calculate") return None # Compute stats arg_cache_dct = {} stats_dct = {} used_agg_func = False for i, (metric_name, _metric_settings) in enumerate(metrics_dct.items()): try: final_kwargs = _metric_settings.copy() opt_arg_names = opt_arg_names_dct[metric_name] custom_arg_names = custom_arg_names_dct[metric_name] custom_reself = resolved_self_dct[metric_name] # Clean up keys for k, v in list(final_kwargs.items()): if k.startswith("check_") or k.startswith("inv_check_") or k in ("tags",): final_kwargs.pop(k, None) # Get metric-specific values _column = final_kwargs.get("column") _group_by = final_kwargs.get("group_by") _agg_func = final_kwargs.get("agg_func") _silence_warnings = final_kwargs.get("silence_warnings") if final_kwargs["to_timedelta"] is None: final_kwargs["to_timedelta"] = custom_reself.wrapper.freq is not None to_timedelta = final_kwargs.get("to_timedelta") title = final_kwargs.pop("title", metric_name) calc_func = final_kwargs.pop("calc_func") resolve_calc_func = final_kwargs.pop("resolve_calc_func", True) post_calc_func = final_kwargs.pop("post_calc_func", None) use_shortcuts = final_kwargs.pop("use_shortcuts", True) use_caching = final_kwargs.pop("use_caching", True) fill_wrap_kwargs = final_kwargs.pop("fill_wrap_kwargs", False) if fill_wrap_kwargs: final_kwargs["wrap_kwargs"] = merge_dicts( dict(to_timedelta=to_timedelta, silence_warnings=_silence_warnings), final_kwargs.get("wrap_kwargs", None), ) apply_to_timedelta = final_kwargs.pop("apply_to_timedelta", False) # Resolve calc_func if resolve_calc_func: if not callable(calc_func): passed_kwargs_out = {} def _getattr_func( obj: tp.Any, attr: str, args: tp.ArgsLike = None, kwargs: tp.KwargsLike = None, call_attr: bool = True, _final_kwargs: tp.Kwargs = final_kwargs, _opt_arg_names: tp.Set[str] = opt_arg_names, _custom_arg_names: tp.Set[str] = custom_arg_names, _arg_cache_dct: tp.Kwargs = arg_cache_dct, _use_shortcuts: bool = 
use_shortcuts, _use_caching: bool = use_caching, ) -> tp.Any: if attr in _final_kwargs: return _final_kwargs[attr] if args is None: args = () if kwargs is None: kwargs = {} if obj is custom_reself: resolve_path_arg = _final_kwargs.pop("resolve_path_" + attr, True) if resolve_path_arg: if call_attr: cond_kwargs = {k: v for k, v in _final_kwargs.items() if k in _opt_arg_names} out = custom_reself.resolve_attr( attr, # do not pass _attr, important for caching args=args, cond_kwargs=cond_kwargs, kwargs=kwargs, custom_arg_names=_custom_arg_names, cache_dct=_arg_cache_dct, use_caching=_use_caching, passed_kwargs_out=passed_kwargs_out, use_shortcuts=_use_shortcuts, ) else: if isinstance(obj, AttrResolverMixin): cls_dir = obj.cls_dir else: cls_dir = dir(type(obj)) if "get_" + attr in cls_dir: _attr = "get_" + attr else: _attr = attr out = getattr(obj, _attr) _select_col_arg = _final_kwargs.pop("select_col_" + attr, False) if _select_col_arg and _column is not None: out = custom_reself.select_col_from_obj( out, _column, wrapper=custom_reself.wrapper.regroup(_group_by), ) passed_kwargs_out["group_by"] = _group_by passed_kwargs_out["column"] = _column return out out = getattr(obj, attr) if callable(out) and call_attr: return out(*args, **kwargs) return out calc_func = custom_reself.deep_getattr( calc_func, getattr_func=_getattr_func, call_last_attr=False, ) if "group_by" in passed_kwargs_out: if "pass_group_by" not in final_kwargs: final_kwargs.pop("group_by", None) if "column" in passed_kwargs_out: if "pass_column" not in final_kwargs: final_kwargs.pop("column", None) # Resolve arguments if callable(calc_func): func_arg_names = get_func_arg_names(calc_func) for k in func_arg_names: if k not in final_kwargs: resolve_arg = final_kwargs.pop("resolve_" + k, False) use_shortcuts_arg = final_kwargs.pop("use_shortcuts_" + k, True) select_col_arg = final_kwargs.pop("select_col_" + k, False) if resolve_arg: try: arg_out = custom_reself.resolve_attr( k, cond_kwargs=final_kwargs, 
custom_arg_names=custom_arg_names, cache_dct=arg_cache_dct, use_caching=use_caching, use_shortcuts=use_shortcuts_arg, ) except AttributeError: continue if select_col_arg and _column is not None: arg_out = custom_reself.select_col_from_obj( arg_out, _column, wrapper=custom_reself.wrapper.regroup(_group_by), ) final_kwargs[k] = arg_out for k in list(final_kwargs.keys()): if k in opt_arg_names: if "pass_" + k in final_kwargs: if not final_kwargs.get("pass_" + k): # first priority final_kwargs.pop(k, None) elif k not in func_arg_names: # second priority final_kwargs.pop(k, None) for k in list(final_kwargs.keys()): if k.startswith("pass_") or k.startswith("resolve_"): final_kwargs.pop(k, None) # cleanup # Call calc_func out = calc_func(**final_kwargs) else: # calc_func is already a result out = calc_func else: # Do not resolve calc_func out = calc_func(custom_reself, _metric_settings) # Call post_calc_func if post_calc_func is not None: out = post_calc_func(custom_reself, out, _metric_settings) # Post-process and store the metric multiple = True if not isinstance(out, dict): multiple = False out = {None: out} for k, v in out.items(): # Resolve title if multiple: if title is None: t = str(k) else: t = title + ": " + str(k) else: t = title # Check result type if checks.is_any_array(v) and not checks.is_series(v): raise TypeError( "calc_func must return either a scalar for one column/group, " "pd.Series for multiple columns/groups, or a dict of such. " f"Not {type(v)}." 
) # Handle apply_to_timedelta if apply_to_timedelta and to_timedelta: v = custom_reself.wrapper.arr_to_timedelta(v, silence_warnings=_silence_warnings) # Select column or aggregate if checks.is_series(v): if _column is None and v.shape[0] == 1: v = v.iloc[0] elif _column is not None: v = custom_reself.select_col_from_obj( v, _column, wrapper=custom_reself.wrapper.regroup(_group_by), ) elif _agg_func is not None and agg_func is not None: v = _agg_func(v) if _agg_func is agg_func: used_agg_func = True elif _agg_func is None and agg_func is not None: if not _silence_warnings: warn( f"Metric '{metric_name}' returned multiple values " "despite having no aggregation function", ) continue # Store metric if t in stats_dct: if not _silence_warnings: warn(f"Duplicate metric title '{t}'") stats_dct[t] = v except Exception as e: warn(f"Metric '{metric_name}' raised an exception") raise e # Return the stats if reself.wrapper.get_ndim(group_by=group_by) == 1: sr = pd.Series( stats_dct, name=reself.wrapper.get_name(group_by=group_by), dtype=object, ) if dropna: sr.replace([np.inf, -np.inf], np.nan, inplace=True) return sr.dropna() return sr if column is not None: sr = pd.Series(stats_dct, name=column, dtype=object) if dropna: sr.replace([np.inf, -np.inf], np.nan, inplace=True) return sr.dropna() return sr if agg_func is not None: if used_agg_func and not silence_warnings: warn( f"Object has multiple columns. Aggregated some metrics using {agg_func}. " "Pass either agg_func=None or per_column=True to return statistics per column. 
@classmethod
def build_metrics_doc(cls, source_cls: tp.Optional[type] = None) -> str:
    """Build the documentation string for the `metrics` attribute.

    The docstring of `metrics` is looked up on `source_cls` with `get_dict_attr`,
    cleaned with `inspect.cleandoc`, and treated as a `string.Template` in which
    the placeholders `metrics` and `cls_name` are substituted.

    Args:
        source_cls (type): Class whose `metrics` docstring serves as the template.

            Defaults to `StatsBuilderMixin`.

    Returns:
        str: Template text with `metrics` replaced by `cls.metrics.prettify()`
        and `cls_name` replaced by `cls.__name__`.
    """
    doc_owner = StatsBuilderMixin if source_cls is None else source_cls
    raw_doc = get_dict_attr(doc_owner, "metrics").__doc__
    doc_template = string.Template(inspect.cleandoc(raw_doc))
    # Placeholder values are computed only after the template has been built
    placeholders = {"metrics": cls.metrics.prettify(), "cls_name": cls.__name__}
    return doc_template.substitute(placeholders)


@classmethod
def override_metrics_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
    """Store the built metrics documentation in `__pdoc__` under `"<class name>.metrics"`.

    Call this method on each subclass that overrides `StatsBuilderMixin.metrics`.

    Args:
        __pdoc__ (dict): Dictionary that receives the documentation entry (modified in place).
        source_cls (type): Forwarded to `build_metrics_doc`.
    """
    doc_key = cls.__name__ + ".metrics"
    __pdoc__[doc_key] = cls.build_metrics_doc(source_cls=source_cls)
).get() ``` """ from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.indicators.custom.adx import * from vectorbtpro.indicators.custom.atr import * from vectorbtpro.indicators.custom.bbands import * from vectorbtpro.indicators.custom.hurst import * from vectorbtpro.indicators.custom.ma import * from vectorbtpro.indicators.custom.macd import * from vectorbtpro.indicators.custom.msd import * from vectorbtpro.indicators.custom.obv import * from vectorbtpro.indicators.custom.ols import * from vectorbtpro.indicators.custom.patsim import * from vectorbtpro.indicators.custom.pivotinfo import * from vectorbtpro.indicators.custom.rsi import * from vectorbtpro.indicators.custom.sigdet import * from vectorbtpro.indicators.custom.stoch import * from vectorbtpro.indicators.custom.supertrend import * from vectorbtpro.indicators.custom.vwap import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
class _ADX(ADX):
    """Average Directional Movement Index (ADX).

    Some traders use this indicator to determine the strength of a trend.

    See [Average Directional Index (ADX)](https://www.investopedia.com/terms/a/adx.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plus_di_trace_kwargs: tp.KwargsLike = None,
        minus_di_trace_kwargs: tp.KwargsLike = None,
        adx_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `ADX.plus_di`, `ADX.minus_di`, and `ADX.adx`.

        Args:
            column (str): Name of the column to plot.
            plus_di_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.plus_di`.
            minus_di_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.minus_di`.
            adx_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.adx`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.

                A new figure is created with `make_figure` if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Returns:
            The figure after the three line traces have been added.

        Usage:
            ```pycon
            >>> vbt.ADX.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/ADX.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/ADX.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        target_fig = make_figure() if fig is None else fig
        target_fig.update_layout(**layout_kwargs)

        # Default trace settings are merged with whatever the caller supplied
        plus_di_style = merge_dicts(
            dict(name="+DI", line=dict(color=plot_cfg["color_schema"]["green"], dash="dot")),
            {} if plus_di_trace_kwargs is None else plus_di_trace_kwargs,
        )
        minus_di_style = merge_dicts(
            dict(name="-DI", line=dict(color=plot_cfg["color_schema"]["red"], dash="dot")),
            {} if minus_di_trace_kwargs is None else minus_di_trace_kwargs,
        )
        adx_style = merge_dicts(
            dict(name="ADX", line=dict(color=plot_cfg["color_schema"]["lightblue"])),
            {} if adx_trace_kwargs is None else adx_trace_kwargs,
        )

        # Outputs are fetched lazily, one per iteration, in drawing order
        for output_name, trace_style in (
            ("plus_di", plus_di_style),
            ("minus_di", minus_di_style),
            ("adx", adx_style),
        ):
            target_fig = getattr(selected, output_name).vbt.lineplot(
                trace_kwargs=trace_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=target_fig,
            )
        return target_fig
class _ATR(ATR):
    """Average True Range (ATR).

    The indicator provides an indication of the degree of price volatility.
    Strong moves, in either direction, are often accompanied by large ranges,
    or large True Ranges.

    See [Average True Range - ATR](https://www.investopedia.com/terms/a/atr.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        tr_trace_kwargs: tp.KwargsLike = None,
        atr_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `ATR.tr` and `ATR.atr`.

        Args:
            column (str): Name of the column to plot.
            tr_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ATR.tr`.
            atr_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ATR.atr`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.

                A new figure is created with `make_figure` if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Returns:
            The figure after both line traces have been added.

        Usage:
            ```pycon
            >>> vbt.ATR.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/ATR.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/ATR.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        target_fig = make_figure() if fig is None else fig
        target_fig.update_layout(**layout_kwargs)

        # Default trace settings are merged with whatever the caller supplied
        tr_style = merge_dicts(
            dict(name="TR", line=dict(color=plot_cfg["color_schema"]["lightblue"])),
            {} if tr_trace_kwargs is None else tr_trace_kwargs,
        )
        atr_style = merge_dicts(
            dict(name="ATR", line=dict(color=plot_cfg["color_schema"]["lightpurple"])),
            {} if atr_trace_kwargs is None else atr_trace_kwargs,
        )

        # Outputs are fetched lazily, one per iteration, in drawing order
        for output_name, trace_style in (("tr", tr_style), ("atr", atr_style)):
            target_fig = getattr(selected, output_name).vbt.lineplot(
                trace_kwargs=trace_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=target_fig,
            )
        return target_fig
class _BBANDS(BBANDS):
    """Bollinger Bands (BBANDS).

    A Bollinger Band® is a technical analysis tool defined by a set of lines plotted two standard
    deviations (positively and negatively) away from a simple moving average (SMA) of the security's
    price, but can be adjusted to user preferences.

    See [Bollinger Band®](https://www.investopedia.com/terms/b/bollingerbands.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        upper_trace_kwargs: tp.KwargsLike = None,
        middle_trace_kwargs: tp.KwargsLike = None,
        lower_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `BBANDS.upper`, `BBANDS.middle`, and `BBANDS.lower` against `BBANDS.close`.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `BBANDS.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.close`.
            upper_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.upper`.
            middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.middle`.
            lower_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.lower`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.

                A new figure is created with `make_figure` if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Returns:
            The figure after the band traces (and optionally the close trace) have been added.

        Usage:
            ```pycon
            >>> vbt.BBANDS.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/BBANDS.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/BBANDS.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        target_fig = make_figure() if fig is None else fig
        target_fig.update_layout(**layout_kwargs)

        # Default trace settings are merged with whatever the caller supplied
        lower_style = merge_dicts(
            dict(
                name="Lower band",
                line=dict(color=adjust_opacity(plot_cfg["color_schema"]["gray"], 0.5)),
            ),
            {} if lower_trace_kwargs is None else lower_trace_kwargs,
        )
        upper_style = merge_dicts(
            dict(
                name="Upper band",
                line=dict(color=adjust_opacity(plot_cfg["color_schema"]["gray"], 0.5)),
                fill="tonexty",
                fillcolor="rgba(128, 128, 128, 0.2)",
            ),
            {} if upper_trace_kwargs is None else upper_trace_kwargs,
        )
        middle_style = merge_dicts(
            dict(name="Middle band", line=dict(color=plot_cfg["color_schema"]["lightblue"])),
            {} if middle_trace_kwargs is None else middle_trace_kwargs,
        )
        # The close style is always prepared, even when the close trace is not drawn
        close_style = merge_dicts(
            dict(name="Close", line=dict(color=plot_cfg["color_schema"]["blue"])),
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )

        # The lower band goes first and the upper band directly after it.
        # NOTE(review): the upper band's `fill="tonexty"` presumably depends on this order - confirm.
        draw_plan = [
            ("lower", lower_style),
            ("upper", upper_style),
            ("middle", middle_style),
        ]
        if plot_close:
            draw_plan.append(("close", close_style))

        # Outputs are fetched lazily, one per iteration, in drawing order
        for output_name, trace_style in draw_plan:
            target_fig = getattr(selected, output_name).vbt.lineplot(
                trace_kwargs=trace_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=target_fig,
            )
        return target_fig
class _HURST(HURST):
    """Moving Hurst exponent (HURST)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        hurst_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `HURST.hurst`.

        Args:
            column (str): Name of the column to plot.
            hurst_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `HURST.hurst`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.

                Handed to `lineplot` unchanged, including when it is None.
            **layout_kwargs: Keyword arguments forwarded to `lineplot`;
                intended as layout settings for `fig.update_layout`.

        Returns:
            The figure returned by `lineplot`.

        Usage:
            ```pycon
            >>> ohlcv = vbt.YFData.pull(
            ...     "BTC-USD",
            ...     start="2020-01-01",
            ...     end="2024-01-01"
            ... ).get()
            >>> vbt.HURST.run(ohlcv["Close"]).plot().show()
            ```

            ![](/assets/images/api/HURST.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/HURST.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plot_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        # Default trace settings are merged with whatever the caller supplied
        default_style = dict(name="HURST", line=dict(color=plot_cfg["color_schema"]["lightblue"]))
        hurst_style = merge_dicts(
            default_style,
            {} if hurst_trace_kwargs is None else hurst_trace_kwargs,
        )

        return selected.hurst.vbt.lineplot(
            trace_kwargs=hurst_style,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs,
        )
# ===================================================================================

"""Module with `MA`."""

from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "MA",
]

__pdoc__ = {}

MA = IndicatorFactory(
    class_name="MA",
    module_name=__name__,
    input_names=["close"],
    param_names=["window", "wtype"],
    output_names=["ma"],
).with_apply_func(
    nb.ma_nb,
    kwargs_as_args=["minp", "adjust"],
    param_settings={
        "wtype": {
            "dtype": generic_enums.WType,
            "dtype_kwargs": {"enum_unkval": None},
            "post_index_func": lambda index: index.str.lower(),
        },
    },
    window=14,
    wtype="simple",
    minp=None,
    adjust=False,
)


class _MA(MA):
    """Moving Average (MA).

    A moving average is a widely used indicator in technical analysis that helps smooth out
    price action by filtering out the “noise” from random short-term price fluctuations.

    See [Moving Average (MA)](https://www.investopedia.com/terms/m/movingaverage.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        ma_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `MA.ma`, optionally together with `MA.close`.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `MA.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MA.close`.
            ma_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MA.ma`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.MA.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/MA.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/MA.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        # Start from a new figure only when the caller did not provide one
        fig = make_figure() if fig is None else fig
        fig.update_layout(**layout_kwargs)

        # Defaults first, user overrides second
        close_trace_kwargs = merge_dicts(
            {"name": "Close", "line": {"color": plotting_cfg["color_schema"]["blue"]}},
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )
        ma_trace_kwargs = merge_dicts(
            {"name": "MA", "line": {"color": plotting_cfg["color_schema"]["lightblue"]}},
            {} if ma_trace_kwargs is None else ma_trace_kwargs,
        )

        # The close trace is added before the moving average trace
        if plot_close:
            fig = selected.close.vbt.lineplot(
                trace_kwargs=close_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        return selected.ma.vbt.lineplot(
            trace_kwargs=ma_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )


# Expose the docstring and plotting method of the helper subclass on the generated class
setattr(MA, "__doc__", _MA.__doc__)
setattr(MA, "plot", _MA.plot)
MA.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `MACD`."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "MACD",
]

__pdoc__ = {}

MACD = IndicatorFactory(
    class_name="MACD",
    module_name=__name__,
    input_names=["close"],
    param_names=["fast_window", "slow_window", "signal_window", "wtype", "macd_wtype", "signal_wtype"],
    output_names=["macd", "signal"],
    lazy_outputs=dict(
        # `hist` is derived from the `macd` and `signal` outputs on access
        hist=lambda self: self.wrapper.wrap(
            nb.macd_hist_nb(
                to_2d_array(self.macd),
                to_2d_array(self.signal),
            ),
        ),
    ),
).with_apply_func(
    nb.macd_nb,
    kwargs_as_args=["minp", "macd_minp", "signal_minp", "adjust", "macd_adjust", "signal_adjust"],
    param_settings=dict(
        wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
        macd_wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
        signal_wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
    ),
    fast_window=12,
    slow_window=26,
    signal_window=9,
    wtype="exp",
    macd_wtype=None,
    signal_wtype=None,
    minp=None,
    macd_minp=None,
    signal_minp=None,
    adjust=False,
    macd_adjust=None,
    signal_adjust=None,
)


class _MACD(MACD):
    """Moving Average Convergence Divergence (MACD).

    Is a trend-following momentum indicator that shows the relationship between
    two moving averages of prices.

    See [Moving Average Convergence Divergence – MACD](https://www.investopedia.com/terms/m/macd.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        macd_trace_kwargs: tp.KwargsLike = None,
        signal_trace_kwargs: tp.KwargsLike = None,
        hist_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `MACD.macd`, `MACD.signal` and `MACD.hist`.

        Args:
            column (str): Name of the column to plot.
            macd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MACD.macd`.
            signal_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MACD.signal`.
            hist_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar` for `MACD.hist`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.MACD.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/MACD.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/MACD.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.module_ import assert_can_import
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        assert_can_import("plotly")
        import plotly.graph_objects as go
        from vectorbtpro.utils.figure import make_figure

        self_col = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
            # Bar gap is only reset on figures created here, not on caller-supplied ones
            fig.update_layout(bargap=0)
        fig.update_layout(**layout_kwargs)

        if macd_trace_kwargs is None:
            macd_trace_kwargs = {}
        if signal_trace_kwargs is None:
            signal_trace_kwargs = {}
        if hist_trace_kwargs is None:
            hist_trace_kwargs = {}
        macd_trace_kwargs = merge_dicts(
            dict(name="MACD", line=dict(color=plotting_cfg["color_schema"]["lightblue"])), macd_trace_kwargs
        )
        signal_trace_kwargs = merge_dicts(
            dict(name="Signal", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])), signal_trace_kwargs
        )

        fig = self_col.macd.vbt.lineplot(
            trace_kwargs=macd_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
        fig = self_col.signal.vbt.lineplot(
            trace_kwargs=signal_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        # Plot hist
        # Read the lazy `hist` output once and reuse it for the values and the index
        hist_sr = self_col.hist
        hist = hist_sr.values
        hist_diff = generic_nb.diff_1d_nb(hist)
        # Bars default to silver; sign of the value and of its change select the color
        marker_colors = np.full(hist.shape, adjust_opacity("silver", 0.75), dtype=object)
        marker_colors[(hist > 0) & (hist_diff > 0)] = adjust_opacity("green", 0.75)
        marker_colors[(hist > 0) & (hist_diff <= 0)] = adjust_opacity("lightgreen", 0.75)
        marker_colors[(hist < 0) & (hist_diff < 0)] = adjust_opacity("red", 0.75)
        marker_colors[(hist < 0) & (hist_diff >= 0)] = adjust_opacity("lightcoral", 0.75)

        _hist_trace_kwargs = merge_dicts(
            dict(
                name="Histogram",
                x=hist_sr.index,
                y=hist,
                marker_color=marker_colors,
                marker_line_width=0,
            ),
            hist_trace_kwargs,
        )
        hist_bar = go.Bar(**_hist_trace_kwargs)
        if add_trace_kwargs is None:
            add_trace_kwargs = {}
        fig.add_trace(hist_bar, **add_trace_kwargs)

        return fig


# Expose the docstring and plotting method of the helper subclass on the generated class
setattr(MACD, "__doc__", _MACD.__doc__)
setattr(MACD, "plot", _MACD.plot)
MACD.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `MSD`."""

from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "MSD",
]

__pdoc__ = {}

MSD = IndicatorFactory(
    class_name="MSD",
    module_name=__name__,
    input_names=["close"],
    param_names=["window", "wtype"],
    output_names=["msd"],
).with_apply_func(
    nb.msd_nb,
    kwargs_as_args=["minp", "adjust", "ddof"],
    param_settings=dict(
        wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        )
    ),
    window=14,
    wtype="simple",
    minp=None,
    adjust=False,
    ddof=0,
)


class _MSD(MSD):
    """Moving Standard Deviation (MSD).

    Standard deviation is an indicator that measures the size of an asset's recent price moves
    in order to predict how volatile the price may be in the future."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        msd_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `MSD.msd`.

        Args:
            column (str): Name of the column to plot.
            msd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MSD.msd`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.MSD.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/MSD.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/MSD.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if msd_trace_kwargs is None:
            msd_trace_kwargs = {}
        msd_trace_kwargs = merge_dicts(
            dict(name="MSD", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            msd_trace_kwargs,
        )

        # Layout options are forwarded to `lineplot` together with the (possibly missing) figure
        fig = self_col.msd.vbt.lineplot(
            trace_kwargs=msd_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs,
        )

        return fig


# Expose the docstring and plotting method of the helper subclass on the generated class
setattr(MSD, "__doc__", _MSD.__doc__)
setattr(MSD, "plot", _MSD.plot)
MSD.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `OBV`."""

from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "OBV",
]

__pdoc__ = {}

OBV = IndicatorFactory(
    class_name="OBV",
    module_name=__name__,
    short_name="obv",
    input_names=["close", "volume"],
    param_names=[],
    output_names=["obv"],
).with_custom_func(nb.obv_nb)


class _OBV(OBV):
    """On-balance volume (OBV).

    It relates price and volume in the stock market. OBV is based on a cumulative total volume.

    See [On-Balance Volume (OBV)](https://www.investopedia.com/terms/o/onbalancevolume.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        obv_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `OBV.obv` as a line.

        Args:
            column (str): Name of the column to plot.
            obv_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OBV.obv`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.OBV.run(ohlcv['Close'], ohlcv['Volume']).plot().show()
            ```

            ![](/assets/images/api/OBV.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/OBV.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        # Start from a new figure only when the caller did not provide one
        fig = make_figure() if fig is None else fig
        fig.update_layout(**layout_kwargs)

        # Defaults first, user overrides second
        default_trace_kwargs = {
            "name": "OBV",
            "line": {"color": plotting_cfg["color_schema"]["lightblue"]},
        }
        user_trace_kwargs = {} if obv_trace_kwargs is None else obv_trace_kwargs

        return selected.obv.vbt.lineplot(
            trace_kwargs=merge_dicts(default_trace_kwargs, user_trace_kwargs),
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )


# Expose the docstring and plotting method of the helper subclass on the generated class
setattr(OBV, "__doc__", _OBV.__doc__)
setattr(OBV, "plot", _OBV.plot)
OBV.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `OLS`."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "OLS",
]

__pdoc__ = {}

OLS = IndicatorFactory(
    class_name="OLS",
    module_name=__name__,
    short_name="ols",
    input_names=["x", "y"],
    param_names=["window", "norm_window"],
    output_names=["slope", "intercept", "zscore"],
    lazy_outputs=dict(
        # `pred` is derived from `x`, `slope` and `intercept` on access
        pred=lambda self: self.wrapper.wrap(
            nb.ols_pred_nb(
                to_2d_array(self.x),
                to_2d_array(self.slope),
                to_2d_array(self.intercept),
            ),
        ),
        # `error` is derived from `y` and `pred` on access
        error=lambda self: self.wrapper.wrap(
            nb.ols_error_nb(
                to_2d_array(self.y),
                to_2d_array(self.pred),
            ),
        ),
        # `angle` is derived from `slope` on access
        angle=lambda self: self.wrapper.wrap(
            nb.ols_angle_nb(
                to_2d_array(self.slope),
            ),
        ),
    ),
).with_apply_func(
    nb.ols_nb,
    kwargs_as_args=["minp", "ddof", "with_zscore"],
    window=14,
    norm_window=None,
    minp=None,
    ddof=0,
    with_zscore=True,
)


class _OLS(OLS):
    """Rolling Ordinary Least Squares (OLS).

    The indicator can be used to detect changes in the behavior of the stocks against the market or each other.

    See [The Linear Regression of Time and Price](https://www.investopedia.com/articles/trading/09/linear-regression-time-price.asp).
    """

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_y: bool = True,
        y_trace_kwargs: tp.KwargsLike = None,
        pred_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `OLS.pred` against `OLS.y`.

        Args:
            column (str): Name of the column to plot.
            plot_y (bool): Whether to plot `OLS.y`.
            y_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.y`.
            pred_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.pred`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.OLS.run(np.arange(len(ohlcv)), ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/OLS.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/OLS.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        if y_trace_kwargs is None:
            y_trace_kwargs = {}
        if pred_trace_kwargs is None:
            pred_trace_kwargs = {}
        y_trace_kwargs = merge_dicts(
            dict(name="Y", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            y_trace_kwargs,
        )
        pred_trace_kwargs = merge_dicts(
            dict(name="Pred", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
            pred_trace_kwargs,
        )

        # The observed series is added before the prediction trace
        if plot_y:
            fig = self_col.y.vbt.lineplot(
                trace_kwargs=y_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        fig = self_col.pred.vbt.lineplot(
            trace_kwargs=pred_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        return fig

    def plot_zscore(
        self,
        column: tp.Optional[tp.Label] = None,
        alpha: float = 0.05,
        zscore_trace_kwargs: tp.KwargsLike = None,
        add_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `OLS.zscore` with confidence intervals.

        Args:
            column (str): Name of the column to plot.
            alpha (float): The alpha level for the confidence interval.

                The default alpha = .05 returns a 95% confidence interval.
            zscore_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.zscore`.
            add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding the range
                between both confidence intervals.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.OLS.run(np.arange(len(ohlcv)), ohlcv['Close']).plot_zscore().show()
            ```

            ![](/assets/images/api/OLS_zscore.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/OLS_zscore.dark.svg#only-dark){: .iimg loading=lazy }
        """
        import scipy.stats as st
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        # Normalize missing keyword dicts the same way `plot` does
        if zscore_trace_kwargs is None:
            zscore_trace_kwargs = {}
        if add_shape_kwargs is None:
            add_shape_kwargs = {}
        zscore_trace_kwargs = merge_dicts(
            dict(name="Z-score", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            zscore_trace_kwargs,
        )
        fig = self_col.zscore.vbt.lineplot(
            trace_kwargs=zscore_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        # Fill void between limits
        # Anchor the rectangle to the axes of the trace that was just added
        last_trace = fig.data[-1]
        xref = last_trace["xaxis"] if last_trace["xaxis"] is not None else "x"
        yref = last_trace["yaxis"] if last_trace["yaxis"] is not None else "y"
        add_shape_kwargs = merge_dicts(
            dict(
                type="rect",
                xref=xref,
                yref=yref,
                x0=self_col.wrapper.index[0],
                # Two-sided bounds taken from the normal quantile function
                y0=st.norm.ppf(1 - alpha / 2),
                x1=self_col.wrapper.index[-1],
                y1=st.norm.ppf(alpha / 2),
                fillcolor="mediumslateblue",
                opacity=0.2,
                layer="below",
                line_width=0,
            ),
            add_shape_kwargs,
        )
        fig.add_shape(**add_shape_kwargs)

        return fig


# Expose the docstring and plotting methods of the helper subclass on the generated class
setattr(OLS, "__doc__", _OLS.__doc__)
setattr(OLS, "plot", _OLS.plot)
setattr(OLS, "plot_zscore", _OLS.plot_zscore)
OLS.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `PATSIM`."""

import numpy as np

from vectorbtpro import _typing as tp
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "PATSIM",
]

__pdoc__ = {}

PATSIM = IndicatorFactory(
    class_name="PATSIM",
    module_name=__name__,
    short_name="patsim",
    input_names=["close"],
    param_names=[
        "pattern",
        "window",
        "max_window",
        "row_select_prob",
        "window_select_prob",
        "interp_mode",
        "rescale_mode",
        "vmin",
        "vmax",
        "pmin",
        "pmax",
        "invert",
        "error_type",
        "distance_measure",
        "max_error",
        "max_error_interp_mode",
        "max_error_as_maxdist",
        "max_error_strict",
        "min_pct_change",
        "max_pct_change",
        "min_similarity",
    ],
    output_names=["similarity"],
).with_apply_func(
    generic_nb.rolling_pattern_similarity_nb,
    param_settings=dict(
        pattern=dict(is_array_like=True, min_one_dim=True),
        interp_mode=dict(
            dtype=generic_enums.InterpMode,
            post_index_func=lambda index: index.str.lower(),
        ),
        rescale_mode=dict(
            dtype=generic_enums.RescaleMode,
            post_index_func=lambda index: index.str.lower(),
        ),
        error_type=dict(
            dtype=generic_enums.ErrorType,
            post_index_func=lambda index: index.str.lower(),
        ),
        distance_measure=dict(
            dtype=generic_enums.DistanceMeasure,
            post_index_func=lambda index: index.str.lower(),
        ),
        max_error=dict(is_array_like=True, min_one_dim=True),
        max_error_interp_mode=dict(
            dtype=generic_enums.InterpMode,
            post_index_func=lambda index: index.str.lower(),
        ),
    ),
    window=None,
    max_window=None,
    row_select_prob=1.0,
    window_select_prob=1.0,
    interp_mode="mixed",
    rescale_mode="minmax",
    vmin=np.nan,
    vmax=np.nan,
    pmin=np.nan,
    pmax=np.nan,
    invert=False,
    error_type="absolute",
    distance_measure="mae",
    max_error=np.nan,
    max_error_interp_mode=None,
    max_error_as_maxdist=False,
    max_error_strict=False,
    min_pct_change=np.nan,
    max_pct_change=np.nan,
    min_similarity=np.nan,
)


class _PATSIM(PATSIM):
    """Rolling pattern similarity.

    Based on `vectorbtpro.generic.nb.rolling.rolling_pattern_similarity_nb`."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        similarity_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `PATSIM.similarity`.

        Only the similarity trace is drawn; its y-axis is formatted as a percentage.
        Use `PATSIM.overlay_with_heatmap` to display similarity together with `PATSIM.close`.

        Args:
            column (str): Name of the column to plot.
            similarity_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `PATSIM.similarity`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.PATSIM.run(ohlcv['Close'], np.array([1, 2, 3, 2, 1]), 30).plot().show()
            ```

            ![](/assets/images/api/PATSIM.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/PATSIM.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        # Normalize a missing keyword dict the same way `overlay_with_heatmap` does
        if similarity_trace_kwargs is None:
            similarity_trace_kwargs = {}
        similarity_trace_kwargs = merge_dicts(
            dict(name="Similarity", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            similarity_trace_kwargs,
        )
        fig = self_col.similarity.vbt.lineplot(
            trace_kwargs=similarity_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        # Translate the trace's axis reference (e.g. "y2") into its layout key (e.g. "yaxis2")
        yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
        yaxis = "yaxis" + yref[1:]
        default_layout = {yaxis: dict(tickformat=",.0%")}
        # Apply the default first so that user-provided layout options are applied last
        fig.update_layout(**default_layout)
        fig.update_layout(**layout_kwargs)

        return fig

    def overlay_with_heatmap(
        self,
        column: tp.Optional[tp.Label] = None,
        close_trace_kwargs: tp.KwargsLike = None,
        similarity_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Overlay `PATSIM.similarity` as a heatmap on top of `PATSIM.close`.

        Args:
            column (str): Name of the column to plot.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `PATSIM.close`.
            similarity_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`
                for `PATSIM.similarity`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.PATSIM.run(ohlcv['Close'], np.array([1, 2, 3, 2, 1]), 30).overlay_with_heatmap().show()
            ```

            ![](/assets/images/api/PATSIM_heatmap.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/PATSIM_heatmap.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if close_trace_kwargs is None:
            close_trace_kwargs = {}
        if similarity_trace_kwargs is None:
            similarity_trace_kwargs = {}
        close_trace_kwargs = merge_dicts(
            dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
            close_trace_kwargs,
        )
        similarity_trace_kwargs = merge_dicts(
            dict(
                colorbar=dict(tickformat=",.0%"),
                # Scale runs from fully transparent to the configured light purple
                colorscale=[
                    [0.0, "rgba(0, 0, 0, 0)"],
                    [1.0, plotting_cfg["color_schema"]["lightpurple"]],
                ],
                zmin=0,
                zmax=1,
            ),
            similarity_trace_kwargs,
        )
        fig = self_col.close.vbt.overlay_with_heatmap(
            self_col.similarity,
            trace_kwargs=close_trace_kwargs,
            heatmap_kwargs=dict(y_labels=["Similarity"], trace_kwargs=similarity_trace_kwargs),
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs,
        )

        return fig


# Expose the docstring and plotting methods of the helper subclass on the generated class
setattr(PATSIM, "__doc__", _PATSIM.__doc__)
setattr(PATSIM, "plot", _PATSIM.plot)
setattr(PATSIM, "overlay_with_heatmap", _PATSIM.overlay_with_heatmap)
PATSIM.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `PIVOTINFO`."""

import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.enums import Pivot, TrendMode
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "PIVOTINFO",
]

__pdoc__ = {}

PIVOTINFO = IndicatorFactory(
    class_name="PIVOTINFO",
    module_name=__name__,
    short_name="pivotinfo",
    input_names=["high", "low"],
    param_names=["up_th", "down_th"],
    output_names=["conf_pivot", "conf_idx", "last_pivot", "last_idx"],
    lazy_outputs={
        "conf_value": lambda self: self.wrapper.wrap(
            nb.pivot_value_nb(
                to_2d_array(self.high),
                to_2d_array(self.low),
                to_2d_array(self.conf_pivot),
                to_2d_array(self.conf_idx),
            )
        ),
        "last_value": lambda self: self.wrapper.wrap(
            nb.pivot_value_nb(
                to_2d_array(self.high),
                to_2d_array(self.low),
                to_2d_array(self.last_pivot),
                to_2d_array(self.last_idx),
            )
        ),
        "pivots": lambda self: self.wrapper.wrap(
            nb.pivots_nb(
                to_2d_array(self.conf_pivot),
                to_2d_array(self.conf_idx),
                to_2d_array(self.last_pivot),
            )
        ),
        "modes": lambda self: self.wrapper.wrap(
            nb.modes_nb(
                to_2d_array(self.pivots),
            )
        ),
    },
    attr_settings={
        "conf_pivot": {"dtype": Pivot, "enum_unkval": 0},
        "last_pivot": {"dtype": Pivot, "enum_unkval": 0},
        "pivots": {"dtype": Pivot, "enum_unkval": 0},
        "modes": {"dtype": TrendMode, "enum_unkval": 0},
    },
).with_apply_func(
    nb.pivot_info_nb,
    param_settings={
        "up_th": flex_elem_param_config,
        "down_th": flex_elem_param_config,
    },
)


class _PIVOTINFO(PIVOTINFO):
    """Indicator that returns various information on pivots identified based on thresholds.

    * `conf_pivot` (`vectorbtpro.indicators.enums.Pivot`): the type of the latest confirmed pivot (running)
    * `conf_idx`: the index of the latest confirmed pivot (running)
    * `conf_value`: the high/low value under the latest confirmed pivot (running)
    * `last_pivot` (`vectorbtpro.indicators.enums.Pivot`): the type of the latest pivot (running)
    * `last_idx`: the index of the latest pivot (running)
    * `last_value`: the high/low value under the latest pivot (running)
    * `pivots` (`vectorbtpro.indicators.enums.Pivot`): confirmed pivots stored under their indices
        (looking ahead - use only for plotting!)
    * `modes` (`vectorbtpro.indicators.enums.TrendMode`): modes between confirmed pivot points
        (looking ahead - use only for plotting!)
    """

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        conf_value_trace_kwargs: tp.KwargsLike = None,
        last_value_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `PIVOTINFO.conf_value` and `PIVOTINFO.last_value` as lines.

        Args:
            column (str): Name of the column to plot.
            conf_value_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `PIVOTINFO.conf_value` line.
            last_value_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `PIVOTINFO.last_value` line.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> fig = ohlcv.vbt.ohlcv.plot()
            >>> vbt.PIVOTINFO.run(ohlcv['High'], ohlcv['Low'], 0.1, 0.1).plot(fig=fig).show()
            ```

            ![](/assets/images/api/PIVOTINFO.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/PIVOTINFO.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        # Start from a new figure only when the caller did not provide one
        fig = make_figure() if fig is None else fig
        fig.update_layout(**layout_kwargs)

        # Defaults first, user overrides second
        conf_value_trace_kwargs = merge_dicts(
            {"name": "Confirmed value", "line": {"color": plotting_cfg["color_schema"]["lightblue"]}},
            {} if conf_value_trace_kwargs is None else conf_value_trace_kwargs,
        )
        last_value_trace_kwargs = merge_dicts(
            {"name": "Last value", "line": {"color": plotting_cfg["color_schema"]["lightpurple"]}},
            {} if last_value_trace_kwargs is None else last_value_trace_kwargs,
        )

        # The confirmed-value line is added before the last-value line
        fig = selected.conf_value.vbt.lineplot(
            trace_kwargs=conf_value_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
        return selected.last_value.vbt.lineplot(
            trace_kwargs=last_value_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

    def plot_zigzag(
        self,
        column: tp.Optional[tp.Label] = None,
        zigzag_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot the zig-zag line connecting peaks and valleys.

        Args:
            column (str): Name of the column to plot.
            zigzag_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for zig-zag line.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> fig = ohlcv.vbt.ohlcv.plot()
            >>> vbt.PIVOTINFO.run(ohlcv['High'], ohlcv['Low'], 0.1, 0.1).plot_zigzag(fig=fig).show()
            ```

            ![](/assets/images/api/PIVOTINFO_zigzag.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/PIVOTINFO_zigzag.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]
        selected = self.select_col(column=column)

        # Start from a new figure only when the caller did not provide one
        fig = make_figure() if fig is None else fig
        fig.update_layout(**layout_kwargs)

        zigzag_trace_kwargs = merge_dicts(
            {"name": "ZigZag", "line": {"color": plotting_cfg["color_schema"]["lightblue"]}},
            {} if zigzag_trace_kwargs is None else zigzag_trace_kwargs,
        )

        # Highs at peaks and lows at valleys, merged and ordered by index
        pivot_types = selected.pivots
        peak_values = selected.high[pivot_types == Pivot.Peak]
        valley_values = selected.low[pivot_types == Pivot.Valley]
        zigzag = pd.concat((peak_values, valley_values)).sort_index()

        return zigzag.vbt.lineplot(
            trace_kwargs=zigzag_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )


# Expose the docstring and plotting methods of the helper subclass on the generated class
setattr(PIVOTINFO, "__doc__", _PIVOTINFO.__doc__)
setattr(PIVOTINFO, "plot", _PIVOTINFO.plot)
setattr(PIVOTINFO, "plot_zigzag", _PIVOTINFO.plot_zigzag)
PIVOTINFO.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
RSI = IndicatorFactory(
    class_name="RSI",
    module_name=__name__,
    input_names=["close"],
    param_names=["window", "wtype"],
    output_names=["rsi"],
).with_apply_func(
    nb.rsi_nb,
    kwargs_as_args=["minp", "adjust"],
    param_settings=dict(
        wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        )
    ),
    window=14,
    wtype="wilder",
    minp=None,
    adjust=False,
)


class _RSI(RSI):
    """Relative Strength Index (RSI).

    Compares the magnitude of recent gains and losses over a specified time
    period to measure speed and change of price movements of a security. It is
    primarily used to attempt to identify overbought or oversold conditions in
    the trading of an asset.

    See [Relative Strength Index (RSI)](https://www.investopedia.com/terms/r/rsi.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        limits: tp.Tuple[float, float] = (30, 70),
        rsi_trace_kwargs: tp.KwargsLike = None,
        add_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `RSI.rsi`.

        The y-axis of the RSI trace is set to the range `[-5, 105]` before `layout_kwargs`
        are applied, and a shaded rectangle spanning the whole index is drawn between both
        limits. The rectangle is skipped when the index is empty.

        Args:
            column (str): Name of the column to plot.
            limits (tuple of float): Tuple of the lower and upper limit.
            rsi_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `RSI.rsi`.
            add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding
                the range between both limits.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.RSI.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/RSI.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/RSI.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if rsi_trace_kwargs is None:
            rsi_trace_kwargs = {}
        rsi_trace_kwargs = merge_dicts(
            dict(name="RSI", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            rsi_trace_kwargs,
        )

        fig = self_col.rsi.vbt.lineplot(
            trace_kwargs=rsi_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        # Find the axes the RSI trace landed on; an unset axis means the default "x"/"y"
        rsi_trace = fig.data[-1]
        xref = rsi_trace["xaxis"] if rsi_trace["xaxis"] is not None else "x"
        yref = rsi_trace["yaxis"] if rsi_trace["yaxis"] is not None else "y"
        # Trace refs look like "y2", the matching layout key is "yaxis2"
        yaxis = "yaxis" + yref[1:]
        fig.update_layout(**{yaxis: dict(range=[-5, 105])})
        # User layout goes last so it can override the default range
        fig.update_layout(**layout_kwargs)

        # Fill void between limits (needs at least one index label for x0/x1)
        index = self_col.wrapper.index
        if len(index) > 0:
            add_shape_kwargs = merge_dicts(
                dict(
                    type="rect",
                    xref=xref,
                    yref=yref,
                    x0=index[0],
                    y0=limits[0],
                    x1=index[-1],
                    y1=limits[1],
                    fillcolor="mediumslateblue",
                    opacity=0.2,
                    layer="below",
                    line_width=0,
                ),
                add_shape_kwargs,
            )
            fig.add_shape(**add_shape_kwargs)

        return fig


setattr(RSI, "__doc__", _RSI.__doc__)
setattr(RSI, "plot", _RSI.plot)
RSI.fix_docstrings(__pdoc__)
SIGDET = IndicatorFactory(
    class_name="SIGDET",
    module_name=__name__,
    short_name="sigdet",
    input_names=["close"],
    param_names=["lag", "factor", "influence", "up_factor", "down_factor", "mean_influence", "std_influence"],
    output_names=["signal", "upper_band", "lower_band"],
).with_apply_func(
    nb.signal_detection_nb,
    param_settings=dict(
        factor=flex_elem_param_config,
        influence=flex_elem_param_config,
        up_factor=flex_elem_param_config,
        down_factor=flex_elem_param_config,
        mean_influence=flex_elem_param_config,
        std_influence=flex_elem_param_config,
    ),
    lag=14,
    factor=1.0,
    influence=1.0,
    up_factor=None,
    down_factor=None,
    mean_influence=None,
    std_influence=None,
)


class _SIGDET(SIGDET):
    """Robust peak detection algorithm (using z-scores).

    See https://stackoverflow.com/a/22640362"""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        signal_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `SIGDET.signal` as a line with the `hv` shape.

        Args:
            column (str): Name of the column to plot.
            signal_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SIGDET.signal`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments forwarded to the line plot together with the figure.

        Usage:
            ```pycon
            >>> vbt.SIGDET.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/SIGDET.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/SIGDET.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        palette = settings["plotting"]["color_schema"]
        selected = self.select_col(column=column)

        signal_defaults = dict(name="Signal", line=dict(color=palette["lightblue"], shape="hv"))
        return selected.signal.vbt.lineplot(
            trace_kwargs=merge_dicts(signal_defaults, signal_trace_kwargs),
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs,
        )

    def plot_bands(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        upper_band_trace_kwargs: tp.KwargsLike = None,
        lower_band_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `SIGDET.upper_band` and `SIGDET.lower_band` against `SIGDET.close`.

        Traces are added in the order lower band, upper band, close; the upper band is
        filled down to the previously added trace (`fill="tonexty"`).

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `SIGDET.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SIGDET.close`.
            upper_band_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SIGDET.upper_band`.
            lower_band_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SIGDET.lower_band`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to. A new one is made if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.SIGDET.run(ohlcv['Close']).plot_bands().show()
            ```

            ![](/assets/images/api/SIGDET_plot_bands.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/SIGDET_plot_bands.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        palette = settings["plotting"]["color_schema"]
        selected = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        def _or_empty(kwargs: tp.KwargsLike) -> tp.Kwargs:
            # Missing overrides behave like an empty dict
            return {} if kwargs is None else kwargs

        lower_band_trace_kwargs = merge_dicts(
            dict(
                name="Lower band",
                line=dict(color=adjust_opacity(palette["gray"], 0.5)),
            ),
            _or_empty(lower_band_trace_kwargs),
        )
        upper_band_trace_kwargs = merge_dicts(
            dict(
                name="Upper band",
                line=dict(color=adjust_opacity(palette["gray"], 0.5)),
                fill="tonexty",
                fillcolor="rgba(128, 128, 128, 0.2)",
            ),
            _or_empty(upper_band_trace_kwargs),
        )
        close_trace_kwargs = merge_dicts(
            dict(name="Close", line=dict(color=palette["blue"])),
            _or_empty(close_trace_kwargs),
        )

        # Drawing order matters: the upper band fills towards the lower band added just before it
        to_draw = [
            ("lower_band", lower_band_trace_kwargs),
            ("upper_band", upper_band_trace_kwargs),
        ]
        if plot_close:
            to_draw.append(("close", close_trace_kwargs))

        for attr_name, trace_kwargs in to_draw:
            fig = getattr(selected, attr_name).vbt.lineplot(
                trace_kwargs=trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        return fig


setattr(SIGDET, "__doc__", _SIGDET.__doc__)
setattr(SIGDET, "plot", _SIGDET.plot)
setattr(SIGDET, "plot_bands", _SIGDET.plot_bands)
SIGDET.fix_docstrings(__pdoc__)
STOCH = IndicatorFactory(
    class_name="STOCH",
    module_name=__name__,
    input_names=["high", "low", "close"],
    param_names=["fast_k_window", "slow_k_window", "slow_d_window", "wtype", "slow_k_wtype", "slow_d_wtype"],
    output_names=["fast_k", "slow_k", "slow_d"],
).with_apply_func(
    nb.stoch_nb,
    kwargs_as_args=["minp", "fast_k_minp", "slow_k_minp", "slow_d_minp", "adjust", "slow_k_adjust", "slow_d_adjust"],
    param_settings=dict(
        wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
        slow_k_wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
        slow_d_wtype=dict(
            dtype=generic_enums.WType,
            dtype_kwargs=dict(enum_unkval=None),
            post_index_func=lambda index: index.str.lower(),
        ),
    ),
    fast_k_window=14,
    slow_k_window=3,
    slow_d_window=3,
    wtype="simple",
    slow_k_wtype=None,
    slow_d_wtype=None,
    minp=None,
    fast_k_minp=None,
    slow_k_minp=None,
    slow_d_minp=None,
    adjust=False,
    slow_k_adjust=None,
    slow_d_adjust=None,
)


class _STOCH(STOCH):
    """Stochastic Oscillator (STOCH).

    A stochastic oscillator is a momentum indicator comparing a particular closing price
    of a security to a range of its prices over a certain period of time. It is used to
    generate overbought and oversold trading signals, utilizing a 0-100 bounded range of values.

    See [Stochastic Oscillator](https://www.investopedia.com/terms/s/stochasticoscillator.asp)."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        limits: tp.Tuple[float, float] = (20, 80),
        fast_k_trace_kwargs: tp.KwargsLike = None,
        slow_k_trace_kwargs: tp.KwargsLike = None,
        slow_d_trace_kwargs: tp.KwargsLike = None,
        add_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `STOCH.fast_k`, `STOCH.slow_k`, and `STOCH.slow_d`.

        The y-axis of the last trace is set to the range `[-5, 105]` before `layout_kwargs`
        are applied, and a shaded rectangle spanning the whole index is drawn between both
        limits. The rectangle is skipped when the index is empty.

        Args:
            column (str): Name of the column to plot.
            limits (tuple of float): Tuple of the lower and upper limit.
            fast_k_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `STOCH.fast_k`.
            slow_k_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `STOCH.slow_k`.
            slow_d_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `STOCH.slow_d`.
            add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding
                the range between both limits.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.STOCH.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/STOCH.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/STOCH.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if fast_k_trace_kwargs is None:
            fast_k_trace_kwargs = {}
        if slow_k_trace_kwargs is None:
            slow_k_trace_kwargs = {}
        if slow_d_trace_kwargs is None:
            slow_d_trace_kwargs = {}
        fast_k_trace_kwargs = merge_dicts(
            dict(name="Fast %K", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
            fast_k_trace_kwargs,
        )
        slow_k_trace_kwargs = merge_dicts(
            dict(name="Slow %K", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
            slow_k_trace_kwargs,
        )
        slow_d_trace_kwargs = merge_dicts(
            dict(name="Slow %D", line=dict(color=plotting_cfg["color_schema"]["lightpink"])),
            slow_d_trace_kwargs,
        )

        fig = self_col.fast_k.vbt.lineplot(
            trace_kwargs=fast_k_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
        fig = self_col.slow_k.vbt.lineplot(
            trace_kwargs=slow_k_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
        fig = self_col.slow_d.vbt.lineplot(
            trace_kwargs=slow_d_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        # Find the axes the last trace landed on; an unset axis means the default "x"/"y"
        last_trace = fig.data[-1]
        xref = last_trace["xaxis"] if last_trace["xaxis"] is not None else "x"
        yref = last_trace["yaxis"] if last_trace["yaxis"] is not None else "y"
        # Trace refs look like "y2", the matching layout key is "yaxis2"
        yaxis = "yaxis" + yref[1:]
        fig.update_layout(**{yaxis: dict(range=[-5, 105])})
        # User layout goes last so it can override the default range
        fig.update_layout(**layout_kwargs)

        # Fill void between limits (needs at least one index label for x0/x1)
        index = self_col.wrapper.index
        if len(index) > 0:
            add_shape_kwargs = merge_dicts(
                dict(
                    type="rect",
                    xref=xref,
                    yref=yref,
                    x0=index[0],
                    y0=limits[0],
                    x1=index[-1],
                    y1=limits[1],
                    fillcolor="mediumslateblue",
                    opacity=0.2,
                    layer="below",
                    line_width=0,
                ),
                add_shape_kwargs,
            )
            fig.add_shape(**add_shape_kwargs)

        return fig


setattr(STOCH, "__doc__", _STOCH.__doc__)
setattr(STOCH, "plot", _STOCH.plot)
STOCH.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `SUPERTREND`."""

from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts

__all__ = [
    "SUPERTREND",
]

__pdoc__ = {}

SUPERTREND = IndicatorFactory(
    class_name="SUPERTREND",
    module_name=__name__,
    short_name="supertrend",
    input_names=["high", "low", "close"],
    param_names=["period", "multiplier"],
    output_names=["trend", "direction", "long", "short"],
).with_apply_func(nb.supertrend_nb, period=7, multiplier=3)


class _SUPERTREND(SUPERTREND):
    """Supertrend indicator."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        superl_trace_kwargs: tp.KwargsLike = None,
        supers_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `SUPERTREND.long` and `SUPERTREND.short` against `SUPERTREND.close`.

        Traces are added in the order close (if requested), long, short.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `SUPERTREND.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SUPERTREND.close`.
            superl_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SUPERTREND.long`.
            supers_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `SUPERTREND.short`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to. A new one is made if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.SUPERTREND.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/SUPERTREND.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/SUPERTREND.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        if close_trace_kwargs is None:
            close_trace_kwargs = {}
        if superl_trace_kwargs is None:
            superl_trace_kwargs = {}
        if supers_trace_kwargs is None:
            supers_trace_kwargs = {}
        close_trace_kwargs = merge_dicts(
            dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
            close_trace_kwargs,
        )
        superl_trace_kwargs = merge_dicts(
            dict(name="Long", line=dict(color=plotting_cfg["color_schema"]["green"])),
            superl_trace_kwargs,
        )
        supers_trace_kwargs = merge_dicts(
            dict(name="Short", line=dict(color=plotting_cfg["color_schema"]["red"])),
            supers_trace_kwargs,
        )

        if plot_close:
            fig = self_col.close.vbt.lineplot(
                trace_kwargs=close_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        fig = self_col.long.vbt.lineplot(
            trace_kwargs=superl_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
        fig = self_col.short.vbt.lineplot(
            trace_kwargs=supers_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        return fig


setattr(SUPERTREND, "__doc__", _SUPERTREND.__doc__)
setattr(SUPERTREND, "plot", _SUPERTREND.plot)
SUPERTREND.fix_docstrings(__pdoc__)
def substitute_anchor(wrapper: ArrayWrapper, anchor: tp.Optional[tp.FrequencyLike]) -> tp.Array1d:
    """Translate the reset frequency `anchor` into an array of group lengths.

    Without an anchor, a single group as long as the wrapper's first dimension is returned.
    Otherwise, the lengths come from the wrapper's index grouper built on `anchor`."""
    if anchor is not None:
        anchor_grouper = wrapper.get_index_grouper(anchor)
        return anchor_grouper.get_group_lens()
    n_rows = wrapper.shape[0]
    return np.array([n_rows])


VWAP = IndicatorFactory(
    class_name="VWAP",
    module_name=__name__,
    short_name="vwap",
    input_names=["high", "low", "close", "volume"],
    param_names=["anchor"],
    output_names=["vwap"],
).with_apply_func(
    nb.vwap_nb,
    param_settings=dict(
        anchor=dict(template=RepFunc(substitute_anchor)),
    ),
    anchor="D",
)


class _VWAP(VWAP):
    """Volume-Weighted Average Price (VWAP).

    VWAP is a technical analysis indicator used on intraday charts that resets at the start
    of every new trading session.

    See [Volume-Weighted Average Price (VWAP)](https://www.investopedia.com/terms/v/vwap.asp).

    Anchor can be any index grouper."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        vwap_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot `VWAP.vwap` against `VWAP.close`.

        The close trace (if requested) is added first, followed by the VWAP trace.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `VWAP.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `VWAP.close`.
            vwap_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `VWAP.vwap`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to. A new one is made if None.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.VWAP.run(
            ...    ohlcv['High'],
            ...    ohlcv['Low'],
            ...    ohlcv['Close'],
            ...    ohlcv['Volume'],
            ...    anchor="W"
            ... ).plot().show()
            ```

            ![](/assets/images/api/VWAP.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/VWAP.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        palette = settings["plotting"]["color_schema"]
        selected = self.select_col(column=column)

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        close_trace_kwargs = merge_dicts(
            dict(name="Close", line=dict(color=palette["blue"])),
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )
        vwap_trace_kwargs = merge_dicts(
            dict(name="VWAP", line=dict(color=palette["lightblue"])),
            {} if vwap_trace_kwargs is None else vwap_trace_kwargs,
        )

        to_draw = [("vwap", vwap_trace_kwargs)]
        if plot_close:
            # Close goes underneath the VWAP line, so it is drawn first
            to_draw.insert(0, ("close", close_trace_kwargs))

        for attr_name, trace_kwargs in to_draw:
            fig = getattr(selected, attr_name).vbt.lineplot(
                trace_kwargs=trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        return fig


setattr(VWAP, "__doc__", _VWAP.__doc__)
setattr(VWAP, "plot", _VWAP.plot)
VWAP.fix_docstrings(__pdoc__)
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Modules for building and running indicators. Technical indicators are used to see past trends and anticipate future moves. See [Using Technical Indicators to Develop Trading Strategies](https://www.investopedia.com/articles/trading/11/indicators-and-strategies-explained.asp).""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.indicators.configs import * from vectorbtpro.indicators.custom import * from vectorbtpro.indicators.expr import * from vectorbtpro.indicators.factory import * from vectorbtpro.indicators.nb import * from vectorbtpro.indicators.talib_ import * __exclude_from__all__ = [ "enums", ] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
"""Configs for custom indicators."""

from vectorbtpro.utils.config import ReadonlyConfig

# All three flexible-parameter configs defined below are part of the public API
__all__ = [
    "flex_col_param_config",
    "flex_elem_param_config",
    "flex_row_param_config",
]

flex_elem_param_config = ReadonlyConfig(
    dict(
        is_array_like=True,
        bc_to_input=True,
        broadcast_kwargs=dict(keep_flex=True, min_ndim=2),
    )
)
"""Config for flexible element-wise parameters."""

flex_row_param_config = ReadonlyConfig(
    dict(
        is_array_like=True,
        bc_to_input=0,
        broadcast_kwargs=dict(keep_flex=True, min_ndim=1),
    )
)
"""Config for flexible row-wise parameters."""

flex_col_param_config = ReadonlyConfig(
    dict(
        is_array_like=True,
        bc_to_input=1,
        per_column=True,
        broadcast_kwargs=dict(keep_flex=True, min_ndim=1),
    )
)
"""Config for flexible column-wise parameters."""

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Named tuples and enumerated types for indicators."""

from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify

__pdoc__all__ = __all__ = [
    "Pivot",
    "TrendMode",
    "HurstMethod",
    "SuperTrendAIS",
    "SuperTrendAOS",
]

__pdoc__ = {}

# ############# Enums ############# #

# Each enum is a named tuple type whose field defaults are the member values;
# the module-level instance (e.g. `Pivot`) is what callers compare against.


class PivotT(tp.NamedTuple):
    Valley: int = -1
    Peak: int = 1


Pivot = PivotT()
"""_"""

__pdoc__[
    "Pivot"
] = f"""Pivot.

```python
{prettify(Pivot)}
```
"""


class TrendModeT(tp.NamedTuple):
    Downtrend: int = -1
    Uptrend: int = 1


TrendMode = TrendModeT()
"""_"""

__pdoc__[
    "TrendMode"
] = f"""Trend mode.

```python
{prettify(TrendMode)}
```
"""


class HurstMethodT(tp.NamedTuple):
    Standard: int = 0
    LogRS: int = 1
    RS: int = 2
    DMA: int = 3
    DSOD: int = 4


HurstMethod = HurstMethodT()
"""_"""

__pdoc__[
    "HurstMethod"
] = f"""Hurst method.

```python
{prettify(HurstMethod)}
```
"""


# ############# States ############# #


class SuperTrendAIS(tp.NamedTuple):
    i: int
    high: float
    low: float
    close: float
    prev_close: float
    prev_upper: float
    prev_lower: float
    prev_direction: int
    nobs: int
    weighted_avg: float
    old_wt: float
    period: int
    multiplier: float


__pdoc__[
    "SuperTrendAIS"
] = """A named tuple representing the input state of 
`vectorbtpro.indicators.nb.supertrend_acc_nb`."""


class SuperTrendAOS(tp.NamedTuple):
    nobs: int
    weighted_avg: float
    old_wt: float
    upper: float
    lower: float
    trend: float
    direction: int
    long: float
    short: float


__pdoc__[
    "SuperTrendAOS"
] = """A named tuple representing the output state of 
`vectorbtpro.indicators.nb.supertrend_acc_nb`."""
# ############# Delay ############# #


def delay(x: tp.Array2d, d: float) -> tp.Array2d:
    """Value of `x` `d` days ago (`d` is floored to an integer)."""
    lag = math.floor(d)
    return fshift_nb(x, lag)


def delta(x: tp.Array2d, d: float) -> tp.Array2d:
    """Today's value of `x` minus the value of `x` `d` days ago (`d` is floored to an integer)."""
    lag = math.floor(d)
    return diff_nb(x, lag)


# ############# Cross-section ############# #


def cs_rescale(x: tp.Array2d) -> tp.Array2d:
    """Rescale each row of `x` such that `sum(abs(x)) = 1`."""
    row_abs_sum = np.abs(x).sum(axis=1)
    # Divide column-major so that each row is scaled by its own total
    scaled_t = np.divide(x.T, row_abs_sum)
    return scaled_t.T


def cs_rank(x: tp.Array2d) -> tp.Array2d:
    """Rank cross-sectionally (percentage ranks computed on the transposed array)."""
    ranked_t = rank_nb(x.T, pct=True)
    return ranked_t.T


def cs_demean(x: tp.Array2d, g: tp.GroupByLike, context: tp.KwargsLike = None) -> tp.Array2d:
    """Demean `x` against groups `g` cross-sectionally.

    Groups are built over the columns of `context["wrapper"]`."""
    columns = context["wrapper"].columns
    column_grouper = Grouper(columns, g)
    return demean_nb(x, column_grouper.get_group_map())


# ############# Rolling ############# #


def _one_based_or_nan(local_idx: tp.Array2d) -> tp.Array2d:
    """Add one to window-local positions, turning the `-1` marker into NaN first."""
    missing = local_idx == -1
    if missing.any():
        local_idx = np.where(missing, np.nan, local_idx)
    return local_idx + 1


def ts_min(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling min over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_min_nb(x, window)


def ts_max(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling max over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_max_nb(x, window)


def ts_argmin(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling argmin as a one-based position local to the window."""
    window = math.floor(d)
    return _one_based_or_nan(rolling_argmin_nb(x, window, local=True))


def ts_argmax(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling argmax as a one-based position local to the window."""
    window = math.floor(d)
    return _one_based_or_nan(rolling_argmax_nb(x, window, local=True))


def ts_rank(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling percentage rank over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_rank_nb(x, window, pct=True)


def ts_sum(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling sum over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_sum_nb(x, window)


def ts_product(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling product over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_prod_nb(x, window)


def ts_mean(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling mean over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_mean_nb(x, window)


def ts_wmean(x: tp.Array2d, d: float) -> tp.Array2d:
    """Weighted moving average over the past `d` days with linearly decaying weight."""
    window = math.floor(d)
    return wm_mean_nb(x, window)


def ts_std(x: tp.Array2d, d: float) -> tp.Array2d:
    """Return the rolling standard deviation over a window of `floor(d)`."""
    window = math.floor(d)
    return rolling_std_nb(x, window)


def ts_corr(x: tp.Array2d, y: tp.Array2d, d: float) -> tp.Array2d:
    """Time-serial correlation of `x` and `y` for the past `d` days."""
    window = math.floor(d)
    return rolling_corr_nb(x, y, window)


def ts_cov(x: tp.Array2d, y: tp.Array2d, d: float) -> tp.Array2d:
    """Time-serial covariance of `x` and `y` for the past `d` days."""
    window = math.floor(d)
    return rolling_cov_nb(x, y, window)


def adv(d: float, context: tp.KwargsLike = None) -> tp.Array2d:
    """Rolling mean of `context["volume"]` for the past `d` days."""
    volume = context["volume"]
    return ts_mean(volume, math.floor(d))


# ############# Substitutions ############# #


def returns(context: tp.KwargsLike = None) -> tp.Array2d:
    """Daily close-to-close returns computed from `context["close"]`."""
    close = context["close"]
    return returns_nb(close)


def vwap(context: tp.KwargsLike = None) -> tp.Array2d:
    """VWAP.

    Resets daily if the wrapper's index is a datetime index, otherwise
    the whole index is treated as a single group."""
    wrapper = context["wrapper"]
    if not isinstance(wrapper.index, pd.DatetimeIndex):
        group_lens = np.array([wrapper.shape[0]])
    else:
        group_lens = wrapper.get_index_grouper("D").get_group_lens()
    return vwap_nb(
        context["high"],
        context["low"],
        context["close"],
        context["volume"],
        group_lens,
    )


def cap(context: tp.KwargsLike = None) -> tp.Array2d:
    """Market capitalization, computed as `close * volume`."""
    close, volume = context["close"], context["volume"]
    return close * volume
```python {expr_res_func_config.prettify()} ``` """ wqa101_expr_config = HybridConfig( { 1: "cs_rank(ts_argmax(power(where(returns < 0, ts_std(returns, 20), close), 2.), 5)) - 0.5", 2: "-ts_corr(cs_rank(delta(log(volume), 2)), cs_rank((close - open) / open), 6)", 3: "-ts_corr(cs_rank(open), cs_rank(volume), 10)", 4: "-ts_rank(cs_rank(low), 9)", 5: "cs_rank(open - (ts_sum(vwap, 10) / 10)) * (-abs(cs_rank(close - vwap)))", 6: "-ts_corr(open, volume, 10)", 7: "where(adv(20) < volume, (-ts_rank(abs(delta(close, 7)), 60)) * sign(delta(close, 7)), -1)", 8: "-cs_rank((ts_sum(open, 5) * ts_sum(returns, 5)) - delay(ts_sum(open, 5) * ts_sum(returns, 5), 10))", 9: ( "where(0 < ts_min(delta(close, 1), 5), delta(close, 1), where(ts_max(delta(close, 1), 5) < 0, delta(close," " 1), -delta(close, 1)))" ), 10: ( "cs_rank(where(0 < ts_min(delta(close, 1), 4), delta(close, 1), where(ts_max(delta(close, 1), 4) < 0," " delta(close, 1), -delta(close, 1))))" ), 11: "(cs_rank(ts_max(vwap - close, 3)) + cs_rank(ts_min(vwap - close, 3))) * cs_rank(delta(volume, 3))", 12: "sign(delta(volume, 1)) * (-delta(close, 1))", 13: "-cs_rank(ts_cov(cs_rank(close), cs_rank(volume), 5))", 14: "(-cs_rank(delta(returns, 3))) * ts_corr(open, volume, 10)", 15: "-ts_sum(cs_rank(ts_corr(cs_rank(high), cs_rank(volume), 3)), 3)", 16: "-cs_rank(ts_cov(cs_rank(high), cs_rank(volume), 5))", 17: ( "((-cs_rank(ts_rank(close, 10))) * cs_rank(delta(delta(close, 1), 1))) * cs_rank(ts_rank(volume /" " adv(20), 5))" ), 18: "-cs_rank((ts_std(abs(close - open), 5) + (close - open)) + ts_corr(close, open, 10))", 19: "(-sign((close - delay(close, 7)) + delta(close, 7))) * (1 + cs_rank(1 + ts_sum(returns, 250)))", 20: "((-cs_rank(open - delay(high, 1))) * cs_rank(open - delay(close, 1))) * cs_rank(open - delay(low, 1))", 21: ( "where(((ts_sum(close, 8) / 8) + ts_std(close, 8)) < (ts_sum(close, 2) / 2), -1, where((ts_sum(close, 2) /" " 2) < ((ts_sum(close, 8) / 8) - ts_std(close, 8)), 1, where(volume / adv(20) >= 1, 1, -1)))" 
), 22: "-(delta(ts_corr(high, volume, 5), 5) * cs_rank(ts_std(close, 20)))", 23: "where((ts_sum(high, 20) / 20) < high, -delta(high, 2), 0)", 24: ( "where((delta(ts_sum(close, 100) / 100, 100) / delay(close, 100)) <= 0.05, (-(close - ts_min(close, 100)))," " -delta(close, 3))" ), 25: "cs_rank((((-returns) * adv(20)) * vwap) * (high - close))", 26: "-ts_max(ts_corr(ts_rank(volume, 5), ts_rank(high, 5), 5), 3)", 27: "where(0.5 < cs_rank(ts_sum(ts_corr(cs_rank(volume), cs_rank(vwap), 6), 2) / 2.0), -1, 1)", 28: "cs_rescale((ts_corr(adv(20), low, 5) + ((high + low) / 2)) - close)", 29: ( "ts_min(ts_product(cs_rank(cs_rank(cs_rescale(log(ts_sum(ts_min(cs_rank(cs_rank(-cs_rank(delta(close - 1," " 5)))), 2), 1))))), 1), 5) + ts_rank(delay(-returns, 6), 5)" ), 30: ( "((1.0 - cs_rank((sign(close - delay(close, 1)) + sign(delay(close, 1) - delay(close, 2))) +" " sign(delay(close, 2) - delay(close, 3)))) * ts_sum(volume, 5)) / ts_sum(volume, 20)" ), 31: ( "(cs_rank(cs_rank(cs_rank(ts_wmean(-cs_rank(cs_rank(delta(close, 10))), 10)))) + cs_rank(-delta(close, 3)))" " + sign(cs_rescale(ts_corr(adv(20), low, 12)))" ), 32: "cs_rescale((ts_sum(close, 7) / 7) - close) + (20 * cs_rescale(ts_corr(vwap, delay(close, 5), 230)))", 33: "cs_rank(-(1 - (open / close)))", 34: "cs_rank((1 - cs_rank(ts_std(returns, 2) / ts_std(returns, 5))) + (1 - cs_rank(delta(close, 1))))", 35: "(ts_rank(volume, 32) * (1 - ts_rank((close + high) - low, 16))) * (1 - ts_rank(returns, 32))", 36: ( "((((2.21 * cs_rank(ts_corr(close - open, delay(volume, 1), 15))) + (0.7 * cs_rank(open - close))) + (0.73" " * cs_rank(ts_rank(delay(-returns, 6), 5)))) + cs_rank(abs(ts_corr(vwap, adv(20), 6)))) + (0.6 *" " cs_rank(((ts_sum(close, 200) / 200) - open) * (close - open)))" ), 37: "cs_rank(ts_corr(delay(open - close, 1), close, 200)) + cs_rank(open - close)", 38: "(-cs_rank(ts_rank(close, 10))) * cs_rank(close / open)", 39: ( "(-cs_rank(delta(close, 7) * (1 - cs_rank(ts_wmean(volume / adv(20), 9))))) * (1 + 
cs_rank(ts_sum(returns," " 250)))" ), 40: "(-cs_rank(ts_std(high, 10))) * ts_corr(high, volume, 10)", 41: "((high * low) ** 0.5) - vwap", 42: "cs_rank(vwap - close) / cs_rank(vwap + close)", 43: "ts_rank(volume / adv(20), 20) * ts_rank(-delta(close, 7), 8)", 44: "-ts_corr(high, cs_rank(volume), 5)", 45: ( "-((cs_rank(ts_sum(delay(close, 5), 20) / 20) * ts_corr(close, volume, 2)) * cs_rank(ts_corr(ts_sum(close," " 5), ts_sum(close, 20), 2)))" ), 46: ( "where(0.25 < (((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)), -1," " where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < 0, 1, -(close" " - delay(close, 1))))" ), 47: ( "(((cs_rank(1 / close) * volume) / adv(20)) * ((high * cs_rank(high - close)) / (ts_sum(high, 5) / 5))) -" " cs_rank(vwap - delay(vwap, 5))" ), 48: ( "cs_demean((ts_corr(delta(close, 1), delta(delay(close, 1), 1), 250) * delta(close, 1)) / close," " 'subindustry') / ts_sum((delta(close, 1) / delay(close, 1)) ** 2, 250)" ), 49: ( "where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < (-0.1), 1," " -(close - delay(close, 1)))" ), 50: "-ts_max(cs_rank(ts_corr(cs_rank(volume), cs_rank(vwap), 5)), 5)", 51: ( "where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < (-0.05), 1," " -(close - delay(close, 1)))" ), 52: ( "(((-ts_min(low, 5)) + delay(ts_min(low, 5), 5)) * cs_rank((ts_sum(returns, 240) - ts_sum(returns, 20)) /" " 220)) * ts_rank(volume, 5)" ), 53: "-delta(((close - low) - (high - close)) / (close - low), 9)", 54: "(-((low - close) * (open ** 5))) / ((low - high) * (close ** 5))", 55: "-ts_corr(cs_rank((close - ts_min(low, 12)) / (ts_max(high, 12) - ts_min(low, 12))), cs_rank(volume), 6)", 56: "0 - (1 * (cs_rank(ts_sum(returns, 10) / ts_sum(ts_sum(returns, 2), 3)) * cs_rank(returns * cap)))", 57: "0 - (1 * ((close - vwap) / ts_wmean(cs_rank(ts_argmax(close, 30)), 2)))", 58: 
"-ts_rank(ts_wmean(ts_corr(cs_demean(vwap, 'sector'), volume, 3.92795), 7.89291), 5.50322)", 59: ( "-ts_rank(ts_wmean(ts_corr(cs_demean((vwap * 0.728317) + (vwap * (1 - 0.728317)), 'industry'), volume," " 4.25197), 16.2289), 8.19648)" ), 60: ( "0 - (1 * ((2 * cs_rescale(cs_rank((((close - low) - (high - close)) / (high - low)) * volume))) -" " cs_rescale(cs_rank(ts_argmax(close, 10)))))" ), 61: "cs_rank(vwap - ts_min(vwap, 16.1219)) < cs_rank(ts_corr(vwap, adv(180), 17.9282))", 62: ( "(cs_rank(ts_corr(vwap, ts_sum(adv(20), 22.4101), 9.91009)) < cs_rank((cs_rank(open) + cs_rank(open)) <" " (cs_rank((high + low) / 2) + cs_rank(high)))) * (-1)" ), 63: ( "(cs_rank(ts_wmean(delta(cs_demean(close, 'industry'), 2.25164), 8.22237)) - cs_rank(ts_wmean(ts_corr((vwap" " * 0.318108) + (open * (1 - 0.318108)), ts_sum(adv(180), 37.2467), 13.557), 12.2883))) * (-1)" ), 64: ( "(cs_rank(ts_corr(ts_sum((open * 0.178404) + (low * (1 - 0.178404)), 12.7054), ts_sum(adv(120), 12.7054)," " 16.6208)) < cs_rank(delta((((high + low) / 2) * 0.178404) + (vwap * (1 - 0.178404)), 3.69741))) * (-1)" ), 65: ( "(cs_rank(ts_corr((open * 0.00817205) + (vwap * (1 - 0.00817205)), ts_sum(adv(60), 8.6911), 6.40374)) <" " cs_rank(open - ts_min(open, 13.635))) * (-1)" ), 66: ( "(cs_rank(ts_wmean(delta(vwap, 3.51013), 7.23052)) + ts_rank(ts_wmean((((low * 0.96633) + (low * (1 -" " 0.96633))) - vwap) / (open - ((high + low) / 2)), 11.4157), 6.72611)) * (-1)" ), 67: ( "(cs_rank(high - ts_min(high, 2.14593)) ** cs_rank(ts_corr(cs_demean(vwap, 'sector'), cs_demean(adv(20)," " 'subindustry'), 6.02936))) * (-1)" ), 68: ( "(ts_rank(ts_corr(cs_rank(high), cs_rank(adv(15)), 8.91644), 13.9333) < cs_rank(delta((close * 0.518371) +" " (low * (1 - 0.518371)), 1.06157))) * (-1)" ), 69: ( "(cs_rank(ts_max(delta(cs_demean(vwap, 'industry'), 2.72412), 4.79344)) ** ts_rank(ts_corr((close *" " 0.490655) + (vwap * (1 - 0.490655)), adv(20), 4.92416), 9.0615)) * (-1)" ), 70: ( "(cs_rank(delta(vwap, 1.29456)) ** 
ts_rank(ts_corr(cs_demean(close, 'industry'), adv(50), 17.8256)," " 17.9171)) * (-1)" ), 71: ( "maximum(ts_rank(ts_wmean(ts_corr(ts_rank(close, 3.43976), ts_rank(adv(180), 12.0647), 18.0175), 4.20501)," " 15.6948), ts_rank(ts_wmean(cs_rank((low + open) - (vwap + vwap)) ** 2, 16.4662), 4.4388))" ), 72: ( "cs_rank(ts_wmean(ts_corr((high + low) / 2, adv(40), 8.93345), 10.1519)) /" " cs_rank(ts_wmean(ts_corr(ts_rank(vwap, 3.72469), ts_rank(volume, 18.5188), 6.86671), 2.95011))" ), 73: ( "maximum(cs_rank(ts_wmean(delta(vwap, 4.72775), 2.91864)), ts_rank(ts_wmean((delta((open * 0.147155) + (low" " * (1 - 0.147155)), 2.03608) / ((open * 0.147155) + (low * (1 - 0.147155)))) * (-1), 3.33829), 16.7411)) *" " (-1)" ), 74: ( "(cs_rank(ts_corr(close, ts_sum(adv(30), 37.4843), 15.1365)) < cs_rank(ts_corr(cs_rank((high * 0.0261661) +" " (vwap * (1 - 0.0261661))), cs_rank(volume), 11.4791))) * (-1)" ), 75: "cs_rank(ts_corr(vwap, volume, 4.24304)) < cs_rank(ts_corr(cs_rank(low), cs_rank(adv(50)), 12.4413))", 76: ( "maximum(cs_rank(ts_wmean(delta(vwap, 1.24383), 11.8259)), ts_rank(ts_wmean(ts_rank(ts_corr(cs_demean(low," " 'sector'), adv(81), 8.14941), 19.569), 17.1543), 19.383)) * (-1)" ), 77: ( "minimum(cs_rank(ts_wmean((((high + low) / 2) + high) - (vwap + high), 20.0451))," " cs_rank(ts_wmean(ts_corr((high + low) / 2, adv(40), 3.1614), 5.64125)))" ), 78: ( "cs_rank(ts_corr(ts_sum((low * 0.352233) + (vwap * (1 - 0.352233)), 19.7428), ts_sum(adv(40), 19.7428)," " 6.83313)) ** cs_rank(ts_corr(cs_rank(vwap), cs_rank(volume), 5.77492))" ), 79: ( "cs_rank(delta(cs_demean((close * 0.60733) + (open * (1 - 0.60733)), 'sector'), 1.23438)) <" " cs_rank(ts_corr(ts_rank(vwap, 3.60973), ts_rank(adv(150), 9.18637), 14.6644))" ), 80: ( "(cs_rank(sign(delta(cs_demean((open * 0.868128) + (high * (1 - 0.868128)), 'industry'), 4.04545))) **" " ts_rank(ts_corr(high, adv(10), 5.11456), 5.53756)) * (-1)" ), 81: ( "(cs_rank(log(ts_product(cs_rank(cs_rank(ts_corr(vwap, ts_sum(adv(10), 49.6054), 
8.47743)) ** 4)," " 14.9655))) < cs_rank(ts_corr(cs_rank(vwap), cs_rank(volume), 5.07914))) * (-1)" ), 82: ( "minimum(cs_rank(ts_wmean(delta(open, 1.46063), 14.8717)), ts_rank(ts_wmean(ts_corr(cs_demean(volume," " 'sector'), ((open * 0.634196) + (open * (1 - 0.634196))), 17.4842), 6.92131), 13.4283)) * (-1)" ), 83: ( "(cs_rank(delay((high - low) / (ts_sum(close, 5) / 5), 2)) * cs_rank(cs_rank(volume))) / (((high - low) /" " (ts_sum(close, 5) / 5)) / (vwap - close))" ), 84: "power(ts_rank(vwap - ts_max(vwap, 15.3217), 20.7127), delta(close, 4.96796))", 85: ( "cs_rank(ts_corr((high * 0.876703) + (close * (1 - 0.876703)), adv(30), 9.61331)) **" " cs_rank(ts_corr(ts_rank((high + low) / 2, 3.70596), ts_rank(volume, 10.1595), 7.11408))" ), 86: ( "(ts_rank(ts_corr(close, ts_sum(adv(20), 14.7444), 6.00049), 20.4195) < cs_rank((open + close) - (vwap +" " open))) * (-1)" ), 87: ( "maximum(cs_rank(ts_wmean(delta((close * 0.369701) + (vwap * (1 - 0.369701)), 1.91233), 2.65461))," " ts_rank(ts_wmean(abs(ts_corr(cs_demean(adv(81), 'industry'), close, 13.4132)), 4.89768), 14.4535)) * (-1)" ), 88: ( "minimum(cs_rank(ts_wmean((cs_rank(open) + cs_rank(low)) - (cs_rank(high) + cs_rank(close)), 8.06882))," " ts_rank(ts_wmean(ts_corr(ts_rank(close, 8.44728), ts_rank(adv(60), 20.6966), 8.01266), 6.65053)," " 2.61957))" ), 89: ( "ts_rank(ts_wmean(ts_corr((low * 0.967285) + (low * (1 - 0.967285)), adv(10), 6.94279), 5.51607), 3.79744)" " - ts_rank(ts_wmean(delta(cs_demean(vwap, 'industry'), 3.48158), 10.1466), 15.3012)" ), 90: ( "(cs_rank(close - ts_max(close, 4.66719)) ** ts_rank(ts_corr(cs_demean(adv(40), 'subindustry'), low," " 5.38375), 3.21856)) * (-1)" ), 91: ( "(ts_rank(ts_wmean(ts_wmean(ts_corr(cs_demean(close, 'industry'), volume, 9.74928), 16.398), 3.83219)," " 4.8667) - cs_rank(ts_wmean(ts_corr(vwap, adv(30), 4.01303), 2.6809))) * (-1)" ), 92: ( "minimum(ts_rank(ts_wmean((((high + low) / 2) + close) < (low + open), 14.7221), 18.8683)," " ts_rank(ts_wmean(ts_corr(cs_rank(low), 
cs_rank(adv(30)), 7.58555), 6.94024), 6.80584))" ), 93: ( "ts_rank(ts_wmean(ts_corr(cs_demean(vwap, 'industry'), adv(81), 17.4193), 19.848), 7.54455) /" " cs_rank(ts_wmean(delta((close * 0.524434) + (vwap * (1 - 0.524434)), 2.77377), 16.2664))" ), 94: ( "(cs_rank(vwap - ts_min(vwap, 11.5783)) ** ts_rank(ts_corr(ts_rank(vwap, 19.6462), ts_rank(adv(60)," " 4.02992), 18.0926), 2.70756)) * (-1)" ), 95: ( "cs_rank(open - ts_min(open, 12.4105)) < ts_rank(cs_rank(ts_corr(ts_sum((high + low) / 2, 19.1351)," " ts_sum(adv(40), 19.1351), 12.8742)) ** 5, 11.7584)" ), 96: ( "maximum(ts_rank(ts_wmean(ts_corr(cs_rank(vwap), cs_rank(volume), 3.83878), 4.16783), 8.38151)," " ts_rank(ts_wmean(ts_argmax(ts_corr(ts_rank(close, 7.45404), ts_rank(adv(60), 4.13242), 3.65459)," " 12.6556), 14.0365), 13.4143)) * (-1)" ), 97: ( "(cs_rank(ts_wmean(delta(cs_demean((low * 0.721001) + (vwap * (1 - 0.721001)), 'industry'), 3.3705)," " 20.4523)) - ts_rank(ts_wmean(ts_rank(ts_corr(ts_rank(low, 7.87871), ts_rank(adv(60), 17.255), 4.97547)," " 18.5925), 15.7152), 6.71659)) * (-1)" ), 98: ( "cs_rank(ts_wmean(ts_corr(vwap, ts_sum(adv(5), 26.4719), 4.58418), 7.18088)) -" " cs_rank(ts_wmean(ts_rank(ts_argmin(ts_corr(cs_rank(open), cs_rank(adv(15)), 20.8187), 8.62571), 6.95668)," " 8.07206))" ), 99: ( "(cs_rank(ts_corr(ts_sum((high + low) / 2, 19.8975), ts_sum(adv(60), 19.8975), 8.8136)) <" " cs_rank(ts_corr(low, volume, 6.28259))) * (-1)" ), 100: ( "0 - (1 * (((1.5 * cs_rescale(cs_demean(cs_demean(cs_rank((((close - low) - (high - close)) / (high - low))" " * volume), 'subindustry'), 'subindustry'))) - cs_rescale(cs_demean(ts_corr(close, cs_rank(adv(20)), 5) -" " cs_rank(ts_argmin(close, 30)), 'subindustry'))) * (volume / adv(20))))" ), 101: "(close - open) / ((high - low) + .001)", } ) """_""" __pdoc__[ "wqa101_expr_config" ] = f"""Config with WorldQuant's 101 alpha expressions. See [101 Formulaic Alphas](https://arxiv.org/abs/1601.00991). Can be modified. 
```python {wqa101_expr_config.prettify()} ``` """ # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Factory for building indicators. Run for the examples below: ```pycon >>> from vectorbtpro import * >>> price = pd.DataFrame({ ... 'a': [1, 2, 3, 4, 5], ... 'b': [5, 4, 3, 2, 1] ... }, index=pd.date_range("2020", periods=5)).astype(float) >>> price a b 2020-01-01 1.0 5.0 2020-01-02 2.0 4.0 2020-01-03 3.0 3.0 2020-01-04 4.0 2.0 2020-01-05 5.0 1.0 ```""" import fnmatch import functools import inspect import itertools import re from collections import Counter, OrderedDict from types import ModuleType, FunctionType import numpy as np import pandas as pd from numba import njit from numba.typed import List from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base import indexes, reshaping, combining from vectorbtpro.base.indexing import build_param_indexer from vectorbtpro.base.merging import row_stack_arrays, column_stack_arrays from vectorbtpro.base.reshaping import broadcast_array_to, Default, resolve_ref from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic import nb as generic_nb from vectorbtpro.generic.accessors import BaseAccessor from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.indicators.expr import expr_func_config, expr_res_func_config, wqa101_expr_config from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks from vectorbtpro.utils.array_ import build_nan_mask, squeeze_nan, 
def prepare_params(
    params: tp.MaybeParams,
    param_names: tp.Sequence[str],
    param_settings: tp.Sequence[tp.KwargsLike],
    input_shape: tp.Optional[tp.Shape] = None,
    to_2d: bool = False,
    context: tp.KwargsLike = None,
) -> tp.Tuple[tp.Params, bool]:
    """Prepare parameters.

    Resolves references and performs broadcasting to the input shape.

    For each parameter, the following keys are read from the respective entry of `param_settings`:
    `is_tuple`, `dtype`, `dtype_kwargs`, `is_array_like`, `min_one_dim`, `bc_to_input`,
    `broadcast_kwargs`, and `template`.

    `bc_to_input` can be True to broadcast to the entire `input_shape`, or 0 or 1 to broadcast
    to the respective axis of `input_shape` only. Broadcasting is skipped only if it is
    None or a boolean false, such that the axis value 0 is not mistaken for "disabled".

    If `to_2d` is True and `bc_to_input` is True, broadcasted values are additionally
    reshaped to two dimensions unless `keep_flex` in `broadcast_kwargs` says otherwise.

    Raises `ValueError` if broadcasting to input is requested for a tuple parameter
    or while `input_shape` is None.

    Returns prepared parameters as well as whether the user provided a single parameter combination."""
    # Resolve references
    if context is None:
        context = {}
    pool = dict(zip(param_names, params))
    for k in pool:
        pool[k] = resolve_ref(pool, k)
    params = [pool[k] for k in param_names]

    new_params = []
    single_comb = True
    for i, param_values in enumerate(params):
        # Resolve settings
        _param_settings = resolve_dict(param_settings[i])
        is_tuple = _param_settings.get("is_tuple", False)
        dtype = _param_settings.get("dtype", None)
        if checks.is_mapping_like(dtype):
            dtype_kwargs = _param_settings.get("dtype_kwargs", None)
            if dtype_kwargs is None:
                dtype_kwargs = {}
            if checks.is_namedtuple(dtype):
                param_values = map_enum_fields(param_values, dtype, **dtype_kwargs)
            else:
                param_values = apply_mapping(param_values, dtype, **dtype_kwargs)
        is_array_like = _param_settings.get("is_array_like", False)
        min_one_dim = _param_settings.get("min_one_dim", False)
        bc_to_input = _param_settings.get("bc_to_input", False)
        broadcast_kwargs = merge_dicts(
            dict(require_kwargs=dict(requirements="W")),
            _param_settings.get("broadcast_kwargs", None),
        )
        template = _param_settings.get("template", None)

        if not is_single_param_value(param_values, is_tuple, is_array_like):
            single_comb = False
        new_param_values = params_to_list(param_values, is_tuple, is_array_like)

        if template is not None:
            # Each value is exposed to the template under the name of its parameter
            new_param_values = [
                template.substitute(context={param_names[i]: value, **context}) for value in new_param_values
            ]

        # A plain truthiness test would also treat axis 0 as "do not broadcast"
        bc_disabled = bc_to_input is None or (isinstance(bc_to_input, (bool, np.bool_)) and not bc_to_input)
        if bc_disabled:
            if is_array_like:
                if min_one_dim:
                    new_param_values = list(map(reshaping.to_1d_array, new_param_values))
                else:
                    new_param_values = list(map(np.asarray, new_param_values))
        else:
            # Broadcast to input or its axis
            if is_tuple:
                raise ValueError("Cannot broadcast to input if tuple")
            if input_shape is None:
                raise ValueError("Cannot broadcast to input if input shape is unknown. Pass input_shape.")
            if bc_to_input is True:
                to_shape = input_shape
            else:
                checks.assert_in(bc_to_input, (0, 1))
                # Note that input_shape can be 1D
                if bc_to_input == 0:
                    to_shape = (input_shape[0],)
                else:
                    to_shape = (input_shape[1],) if len(input_shape) > 1 else (1,)
            bc_param_values = reshaping.broadcast(*new_param_values, to_shape=to_shape, **broadcast_kwargs)
            if len(new_param_values) == 1:
                # A single broadcasted object is wrapped to keep working with a list
                bc_param_values = [bc_param_values]
            else:
                bc_param_values = list(bc_param_values)
            if to_2d and bc_to_input is True:
                # If inputs are meant to reshape to 2D, do the same to parameters
                # But only to those that fully resemble inputs (= not raw)
                keep_flex = broadcast_kwargs.get("keep_flex", False)
                reshaped_param_values = bc_param_values.copy()
                for j, param in enumerate(reshaped_param_values):
                    if keep_flex is False or (isinstance(keep_flex, (tuple, list)) and not keep_flex[j]):
                        reshaped_param_values[j] = reshaping.to_2d(param)
                new_param_values = reshaped_param_values
            else:
                new_param_values = bc_param_values
        new_params.append(new_param_values)
    return new_params, single_comb


def build_columns(
    params: tp.Params,
    input_columns: tp.IndexLike,
    level_names: tp.Optional[tp.Sequence[str]] = None,
    hide_levels: tp.Optional[tp.Sequence[tp.Union[str, int]]] = None,
    single_value: tp.Optional[tp.Sequence[bool]] = None,
    param_settings: tp.KwargsLikeSequence = None,
    per_column: bool = False,
    ignore_ranges: bool = False,
    **kwargs,
) -> dict:
    """For each parameter in `params`, create a new column level with parameter values
    and stack it on top of `input_columns`.

    Levels whose position or name is listed in `hide_levels` are built but left out of the
    stacked indexes. Keyword arguments `**kwargs` are passed to `indexes.stack_indexes`.

    Returns a dict with the following keys:

    * `param_indexes` and `rep_param_indexes`: one index per parameter, repeated to the number
        of input columns unless the parameter is defined per column
    * `vis_param_indexes`: non-repeated indexes of visible levels
    * `vis_rep_param_indexes`: repeated indexes of visible levels
    * `param_index`: visible non-repeated indexes stacked together, or None if there are
        no visible levels or any parameter is defined per column
    * `final_index`: visible repeated indexes stacked on top of the (tiled) input columns"""
    if level_names is not None:
        checks.assert_len_equal(params, level_names)
    if hide_levels is None:
        hide_levels = []
    input_columns = indexes.to_any_index(input_columns)

    rep_param_indexes = []
    vis_param_indexes = []
    vis_rep_param_indexes = []
    has_per_column = False
    for i, param_values in enumerate(params):
        level_name = None
        if level_names is not None:
            level_name = level_names[i]
        _single_value = False
        if single_value is not None:
            _single_value = single_value[i]
        _param_settings = resolve_dict(param_settings, i=i)
        dtype = _param_settings.get("dtype", None)
        if checks.is_mapping_like(dtype):
            # Map the values through the mapping built from `dtype` before they become labels
            if checks.is_namedtuple(dtype):
                dtype = to_value_mapping(dtype, reverse=False)
            else:
                dtype = to_value_mapping(dtype, reverse=True)
            param_values = apply_mapping(param_values, dtype)
        _per_column = _param_settings.get("per_column", False)
        _post_index_func = _param_settings.get("post_index_func", None)
        if per_column:
            param_index = indexes.index_from_values(param_values, single_value=_single_value, name=level_name)
            repeat_index = False
            has_per_column = True
        elif _per_column:
            # Each value is broadcast to the number of input columns and the results are concatenated
            param_index = None
            for p in param_values:
                bc_param = broadcast_array_to(p, len(input_columns))
                _param_index = indexes.index_from_values(bc_param, single_value=False, name=level_name)
                if param_index is None:
                    param_index = _param_index
                else:
                    param_index = param_index.append(_param_index)
            if len(param_index) == 1 and len(input_columns) > 1:
                param_index = indexes.repeat_index(param_index, len(input_columns), ignore_ranges=ignore_ranges)
            repeat_index = False
            has_per_column = True
        else:
            param_index = indexes.index_from_values(param_values, single_value=_single_value, name=level_name)
            repeat_index = True
        if _post_index_func is not None:
            param_index = _post_index_func(param_index)
        if repeat_index:
            rep_param_index = indexes.repeat_index(param_index, len(input_columns), ignore_ranges=ignore_ranges)
        else:
            rep_param_index = param_index
        rep_param_indexes.append(rep_param_index)
        if i not in hide_levels and (level_names is None or level_names[i] not in hide_levels):
            vis_param_indexes.append(param_index)
            vis_rep_param_indexes.append(rep_param_index)
    if not per_column:
        n_param_values = len(params[0]) if len(params) > 0 else 1
        input_columns = indexes.tile_index(input_columns, n_param_values, ignore_ranges=ignore_ranges)
    if len(vis_param_indexes) > 0:
        if has_per_column:
            param_index = None
        else:
            param_index = indexes.stack_indexes(vis_param_indexes, **kwargs)
        final_index = indexes.stack_indexes([*vis_rep_param_indexes, input_columns], **kwargs)
    else:
        param_index = None
        final_index = input_columns
    # NOTE(review): `param_indexes` carries the repeated indexes, same as `rep_param_indexes`.
    # Non-repeated indexes are exposed for visible levels only; confirm that this is intended.
    return dict(
        param_indexes=rep_param_indexes,
        rep_param_indexes=rep_param_indexes,
        vis_param_indexes=vis_param_indexes,
        vis_rep_param_indexes=vis_rep_param_indexes,
        param_index=param_index,
        final_index=final_index,
    )


def combine_objs(
    obj: tp.SeriesFrame,
    other: tp.MaybeTupleList[tp.Union[tp.ArrayLike, BaseAccessor]],
    combine_func: tp.Callable,
    *args,
    level_name: tp.Optional[str] = None,
    keys: tp.Optional[tp.IndexLike] = None,
    allow_multiple: bool = True,
    **kwargs,
) -> tp.SeriesFrame:
    """Combines/compares `obj` to `other`, for example, to generate signals.

    Both will broadcast together.
    Pass `other` as a tuple or a list to compare with multiple arguments.
    In this case, a new column level will be created with the name `level_name`,
    unless `keys` is provided.

    See `vectorbtpro.base.accessors.BaseAccessor.combine`."""
    if allow_multiple and isinstance(other, (tuple, list)):
        if keys is None:
            # Derive the labels of the new level from the objects themselves
            keys = indexes.index_from_values(other, name=level_name)
    return obj.vbt.combine(other, combine_func, *args, keys=keys, allow_multiple=allow_multiple, **kwargs)


# Type aliases used in the signatures of indicator classes
IndicatorBaseT = tp.TypeVar("IndicatorBaseT", bound="IndicatorBase")
CacheOutputT = tp.Any
RawOutputT = tp.Tuple[
    tp.List[tp.Array2d],
    tp.List[tp.Tuple[tp.ParamValue, ...]],
    int,
    tp.List[tp.Any],
]
InputListT = tp.List[tp.Array2d]
InputMapperT = tp.Optional[tp.Array1d]
InOutputListT = tp.List[tp.Array2d]
OutputListT = tp.List[tp.Array2d]
ParamListT = tp.List[tp.List[tp.ParamValue]]
MapperListT = tp.List[tp.Index]
OtherListT = tp.List[tp.Any]
PipelineOutputT = tp.Tuple[
    ArrayWrapper,
    InputListT,
    InputMapperT,
    InOutputListT,
    OutputListT,
    ParamListT,
    MapperListT,
    OtherListT,
]
RunOutputT = tp.Union[IndicatorBaseT, tp.Tuple[tp.Any, ...], RawOutputT, CacheOutputT]
RunCombsOutputT = tp.Tuple[IndicatorBaseT, ...]
def combine_indicator_with_other(
    self: IndicatorBaseT,
    other: tp.Union["IndicatorBase", tp.ArrayLike],
    np_func: tp.Callable[[tp.ArrayLike, tp.ArrayLike], tp.Array1d],
) -> tp.SeriesFrame:
    """Apply the binary function `np_func` to the main output of an indicator and `other`.

    If `other` is itself an `IndicatorBase` instance, its main output is used as
    the right-hand operand; any other object is passed to `np_func` unchanged.
    """
    # Resolve the right operand first, then the left one
    right_operand = other.main_output if isinstance(other, IndicatorBase) else other
    return np_func(self.main_output, right_operand)
@classmethod
def run_pipeline(
    cls,
    num_ret_outputs: int,
    custom_func: tp.Callable,
    *args,
    require_input_shape: bool = False,
    input_shape: tp.Optional[tp.ShapeLike] = None,
    input_index: tp.Optional[tp.IndexLike] = None,
    input_columns: tp.Optional[tp.IndexLike] = None,
    inputs: tp.Optional[tp.MappingSequence[tp.ArrayLike]] = None,
    in_outputs: tp.Optional[tp.MappingSequence[tp.ArrayLike]] = None,
    in_output_settings: tp.Optional[tp.MappingSequence[tp.KwargsLike]] = None,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    params: tp.Optional[tp.MaybeParams] = None,
    param_product: bool = False,
    random_subset: tp.Optional[int] = None,
    param_settings: tp.Optional[tp.MappingSequence[tp.KwargsLike]] = None,
    run_unique: bool = False,
    silence_warnings: bool = False,
    per_column: tp.Optional[bool] = None,
    keep_pd: bool = False,
    to_2d: bool = True,
    pass_packed: bool = False,
    pass_input_shape: tp.Optional[bool] = None,
    pass_wrapper: bool = False,
    pass_param_index: bool = False,
    pass_final_index: bool = False,
    pass_single_comb: bool = False,
    level_names: tp.Optional[tp.Sequence[str]] = None,
    hide_levels: tp.Optional[tp.Sequence[tp.Union[str, int]]] = None,
    build_col_kwargs: tp.KwargsLike = None,
    return_raw: tp.Union[bool, str] = False,
    use_raw: tp.Optional[RawOutputT] = None,
    wrapper_kwargs: tp.KwargsLike = None,
    seed: tp.Optional[int] = None,
    **kwargs,
) -> tp.Union[CacheOutputT, RawOutputT, PipelineOutputT]:
    """A pipeline for running an indicator, used by `IndicatorFactory`.

    Args:
        num_ret_outputs (int): The number of output arrays returned by `custom_func`.
        custom_func (callable): A custom calculation function.

            See `IndicatorFactory.with_custom_func`.
        *args: Arguments passed to the `custom_func`.
        require_input_shape (bool): Whether the input shape is required.

            Will set `pass_input_shape` to True (unless it was set explicitly)
            and raise an error if `input_shape` is None.
        input_shape (tuple): Shape to broadcast each input to.

            Can be passed to `custom_func`. See `pass_input_shape`.
        input_index (index_like): Sets index of each input.

            Can be used to label index if no inputs passed.
        input_columns (index_like): Sets columns of each input.

            Can be used to label columns if no inputs passed.
        inputs (mapping or sequence of array_like): A mapping or sequence of input arrays.

            Use mapping to also supply names. If sequence, will convert to a mapping
            using `input_{i}` key.
        in_outputs (mapping or sequence of array_like): A mapping or sequence of in-place output arrays.

            Use mapping to also supply names. If sequence, will convert to a mapping
            using `in_output_{i}` key.
        in_output_settings (dict or sequence of dict): Settings corresponding to each in-place output.

            If mapping, should contain keys from `in_outputs`.

            Following keys are accepted:

            * `dtype`: Create this array using this data type and `np.empty`. Default is None.
        broadcast_named_args (dict): Dictionary with named arguments to broadcast together with inputs.

            You can then pass argument names wrapped with `vectorbtpro.utils.template.Rep`
            and this method will substitute them by their corresponding broadcasted objects.
        broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`
            to broadcast inputs.
        template_context (dict): Context used to substitute templates in `args` and `kwargs`.
        params (mapping or sequence of any): A mapping or sequence of parameters.

            Use mapping to also supply names. If sequence, will convert to a mapping
            using `param_{i}` key. None is treated as no parameters.

            Each element is either an array-like object or a single value of any type.
        param_product (bool): Whether to build a Cartesian product out of all parameters.
        random_subset (int): Number of parameter combinations to pick randomly.
        param_settings (dict or sequence of dict): Settings corresponding to each parameter.

            If mapping, should contain keys from `params`.

            Following keys are accepted:

            * `dtype`: If data type is an enumerated type or other mapping, and a string as parameter
                value was passed, will convert it first.
            * `dtype_kwargs`: Keyword arguments passed to the function processing the data type.
                If data type is enumerated, it will be `vectorbtpro.utils.enum_.map_enum_fields`.
            * `is_tuple`: If tuple was passed, it will be considered as a single value.
                To treat it as multiple values, pack it into a list.
            * `is_array_like`: If array-like object was passed, it will be considered as a single value.
                To treat it as multiple values, pack it into a list.
            * `template`: Template to substitute each parameter value with, before broadcasting to input.
            * `min_one_dim`: Whether to convert any scalar into a one-dimensional array.
                Works only if `bc_to_input` is False.
            * `bc_to_input`: Whether to broadcast parameter to input size. You can also broadcast
                parameter to an axis by passing an integer.
            * `broadcast_kwargs`: Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
            * `per_column`: Whether each parameter value can be split by columns such that it can
                be better reflected in a multi-index. Does not affect broadcasting.
            * `post_index_func`: Function to convert the final index level of the parameter.
                Defaults to None.
        run_unique (bool): Whether to run only on unique parameter combinations.

            Disable if two identical parameter combinations can lead to different results
            (e.g., due to randomness) or if inputs are large and `custom_func` is fast.

            !!! note
                Cache, raw output, and output objects outside of `num_ret_outputs` will be returned
                for unique parameter combinations only.
        silence_warnings (bool): Whether to hide warnings such as coming from `run_unique`.
        per_column (bool): Whether the values of each parameter should be split by columns.

            Defaults to False. Will pass `per_column` to `custom_func` if it's not None.

            Each list of parameter values will broadcast to the number of columns and
            each parameter value will be applied per column rather than per whole input.
            Input shape must be known beforehand.

            Each from inputs, in-outputs, and parameters will be passed to `custom_func`
            with the full shape. Expects the outputs be of the same shape as inputs.
        keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
        to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
        pass_packed (bool): Whether to pass inputs and parameters to `custom_func` as lists.

            If `custom_func` is Numba-compiled, passes tuples.
        pass_input_shape (bool): Whether to pass `input_shape` to `custom_func` as keyword argument.

            Defaults to True if `require_input_shape` is True, otherwise to False.
        pass_wrapper (bool): Whether to pass the input wrapper to `custom_func` as keyword argument.
        pass_param_index (bool): Whether to pass parameter index.
        pass_final_index (bool): Whether to pass final index.
        pass_single_comb (bool): Whether to pass whether there is only one parameter combination.
        level_names (list of str): A list of column level names corresponding to each parameter.

            Must have the same length as `params`.
        hide_levels (list of int or str): A list of level names or indices of parameter levels to hide.
        build_col_kwargs (dict): Keyword arguments passed to `build_columns`.
        return_raw (bool or str): Whether to return raw outputs and hashed parameter tuples
            without further post-processing.

            Pass "outputs" to only return outputs.
        use_raw (tuple): Raw results of a previous run (as returned with `return_raw=True`)
            to be used instead of running `custom_func`.
        wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
        seed (int): Seed to make output deterministic.
        **kwargs: Keyword arguments passed to the `custom_func`.

            Some common arguments include `return_cache` to return cache and `use_cache` to use cache.
            If `use_cache` is False, disables caching completely. Those are only applicable to
            `custom_func` that supports it (`custom_func` created using
            `IndicatorFactory.with_apply_func` are supported by default).

    Returns:
        Array wrapper, list of inputs (`np.ndarray`), input mapper (`np.ndarray`),
        list of in-place outputs (`np.ndarray`), list of outputs (`np.ndarray`),
        list of parameter arrays (`np.ndarray`), list of parameter mappers (`np.ndarray`),
        list of outputs that are outside of `num_ret_outputs`.

        If `return_raw` is "outputs" or `return_cache` is passed and True, returns whatever
        `custom_func` returned. If `return_raw` is otherwise truthy, returns the raw tuple
        `(output_list, param_map, n_input_cols, other_list)`.

    Raises:
        ValueError: If `per_column` is enabled without parameters, if `keep_pd` is combined with
            a Numba-compiled `custom_func`, if `random_subset` is combined with `per_column`,
            if `return_raw` is an unknown string, if in-place outputs are used without a known
            input shape, or if the outputs of `custom_func` have an unexpected number or shape.
    """
    pass_per_column = per_column is not None
    if per_column is None:
        per_column = False
    # NOTE: `params` defaults to None and is only normalized further below,
    # hence guard against None here instead of calling len() on it directly
    if per_column and (params is None or len(params) == 0):
        raise ValueError("per_column cannot be enabled without parameters")
    if require_input_shape:
        checks.assert_not_none(input_shape, arg_name="input_shape")
        if pass_input_shape is None:
            pass_input_shape = True
    if pass_input_shape is None:
        pass_input_shape = False
    if input_index is not None:
        input_index = indexes.to_any_index(input_index)
    if input_columns is not None:
        input_columns = indexes.to_any_index(input_columns)
    if inputs is None:
        inputs = {}
    if not checks.is_mapping(inputs):
        inputs = {"input_" + str(i): input_arr for i, input_arr in enumerate(inputs)}
    input_names = list(inputs.keys())
    input_list = list(inputs.values())
    if in_outputs is None:
        in_outputs = {}
    if not checks.is_mapping(in_outputs):
        in_outputs = {"in_output_" + str(i): in_output for i, in_output in enumerate(in_outputs)}
    in_output_names = list(in_outputs.keys())
    in_output_list = list(in_outputs.values())
    if in_output_settings is None:
        in_output_settings = {}
    if checks.is_mapping(in_output_settings):
        checks.assert_dict_valid(in_output_settings, [in_output_names, "dtype"])
        in_output_settings = [in_output_settings.get(k, None) for k in in_output_names]
    if broadcast_named_args is None:
        broadcast_named_args = {}
    if broadcast_kwargs is None:
        broadcast_kwargs = {}
    if template_context is None:
        template_context = {}
    if params is None:
        params = {}
    if not checks.is_mapping(params):
        params = {"param_" + str(i): param for i, param in enumerate(params)}
    param_names = list(params.keys())
    param_list = list(params.values())
    if param_settings is None:
        param_settings = {}
    if checks.is_mapping(param_settings):
        checks.assert_dict_valid(
            param_settings,
            [
                param_names,
                [
                    "dtype",
                    "dtype_kwargs",
                    "is_tuple",
                    "is_array_like",
                    "template",
                    "min_one_dim",
                    "bc_to_input",
                    "broadcast_kwargs",
                    "per_column",
                    "post_index_func",
                ],
            ],
        )
        param_settings = [param_settings.get(k, None) for k in param_names]
    if hide_levels is None:
        hide_levels = []
    if build_col_kwargs is None:
        build_col_kwargs = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}
    if keep_pd and checks.is_numba_func(custom_func):
        raise ValueError("Cannot pass pandas objects to a Numba-compiled custom_func. Set keep_pd to False.")

    # Set seed
    if seed is not None:
        set_seed(seed)

    if input_shape is not None:
        input_shape = reshaping.to_tuple_shape(input_shape)
    if len(inputs) > 0 or len(in_outputs) > 0 or len(broadcast_named_args) > 0:
        # Broadcast inputs, in-outputs, and named args
        # If input_shape is provided, will broadcast all inputs to this shape
        broadcast_args = merge_dicts(inputs, in_outputs, broadcast_named_args)
        broadcast_kwargs = merge_dicts(
            dict(
                to_shape=input_shape,
                index_from=input_index,
                columns_from=input_columns,
                require_kwargs=dict(requirements="W"),
                post_func=None if keep_pd else np.asarray,
                to_pd=True,
            ),
            broadcast_kwargs,
        )
        broadcast_args, wrapper = reshaping.broadcast(broadcast_args, return_wrapper=True, **broadcast_kwargs)
        input_shape, input_index, input_columns = wrapper.shape, wrapper.index, wrapper.columns
        if input_index is None:
            input_index = pd.RangeIndex(start=0, step=1, stop=input_shape[0])
        if input_columns is None:
            input_columns = pd.RangeIndex(start=0, step=1, stop=input_shape[1] if len(input_shape) > 1 else 1)
        input_list = [broadcast_args[input_name] for input_name in input_names]
        in_output_list = [broadcast_args[in_output_name] for in_output_name in in_output_names]
        broadcast_named_args = {arg_name: broadcast_args[arg_name] for arg_name in broadcast_named_args}
    else:
        wrapper = None

    # Reshape input shape
    input_shape_ready = input_shape
    input_shape_2d = input_shape
    if input_shape is not None:
        input_shape_2d = input_shape if len(input_shape) > 1 else (input_shape[0], 1)
    if to_2d:
        if input_shape is not None:
            input_shape_ready = input_shape_2d  # ready for custom_func
    if wrapper is not None:
        wrapper_ready = wrapper
    elif input_index is not None and input_columns is not None and input_shape_ready is not None:
        wrapper_ready = ArrayWrapper(input_index, input_columns, len(input_shape_ready))
    else:
        wrapper_ready = None

    # Prepare inputs
    input_list_ready = []
    for input_arr in input_list:
        new_input = input_arr
        if to_2d:
            new_input = reshaping.to_2d(input_arr)
        if keep_pd and isinstance(new_input, np.ndarray):
            # Keep as pandas object
            new_input = ArrayWrapper(input_index, input_columns, new_input.ndim).wrap(new_input)
        input_list_ready.append(new_input)

    # Prepare parameters
    param_context = merge_dicts(
        broadcast_named_args,
        dict(
            input_shape=input_shape_ready,
            wrapper=wrapper_ready,
            **dict(zip(input_names, input_list_ready)),
            pre_sub_args=args,
            pre_sub_kwargs=kwargs,
        ),
        template_context,
    )
    # NOTE: input_shape instead of input_shape_ready since parameters should
    # broadcast by the same rules as inputs
    param_list, single_comb = prepare_params(
        param_list,
        param_names,
        param_settings,
        input_shape=input_shape,
        to_2d=to_2d,
        context=param_context,
    )
    single_value = list(map(lambda x: len(x) == 1, param_list))
    if len(param_list) > 1:
        if level_names is not None:
            # Check level names
            checks.assert_len_equal(param_list, level_names)
            # Columns should be free of the specified level names
            if input_columns is not None:
                for level_name in level_names:
                    if level_name is not None:
                        checks.assert_level_not_exists(input_columns, level_name)
        if param_product:
            # Make Cartesian product out of all params
            param_list = create_param_product(param_list)
    if len(param_list) > 0:
        # Broadcast such that each array has the same length
        if per_column:
            # The number of parameters should match the number of columns before split
            param_list = broadcast_params(param_list, to_n=input_shape_2d[1])
        else:
            param_list = broadcast_params(param_list)
    if random_subset is not None:
        # Pick combinations randomly
        if per_column:
            raise ValueError("Cannot select random subset when per_column=True")
        random_indices = np.sort(np.random.permutation(np.arange(len(param_list[0])))[:random_subset])
        param_list = [[values[i] for i in random_indices] for values in param_list]
    n_param_values = len(param_list[0]) if len(param_list) > 0 else 1
    use_run_unique = False
    param_list_unique = param_list
    if not per_column and run_unique:
        try:
            # Try to get all unique parameter combinations
            param_tuples = list(zip(*param_list))
            unique_param_tuples = list(OrderedDict.fromkeys(param_tuples).keys())
            if len(unique_param_tuples) < len(param_tuples):
                param_list_unique = list(map(list, zip(*unique_param_tuples)))
                use_run_unique = True
        except Exception:
            # Best effort: parameter values that cannot be hashed or compared (such as arrays)
            # simply disable the optimization. Unlike a bare except, this doesn't swallow
            # KeyboardInterrupt and SystemExit.
            pass
    if checks.is_numba_func(custom_func):
        # Numba can't stand untyped lists
        param_list_ready = [to_typed_list(values) for values in param_list_unique]
    else:
        param_list_ready = param_list_unique
    n_unique_param_values = len(param_list_unique[0]) if len(param_list_unique) > 0 else 1

    # Build column hierarchy for execution
    if len(param_list) > 0 and input_columns is not None and (pass_param_index or pass_final_index):
        # Build new column levels on top of input levels
        build_columns_meta = build_columns(
            param_list,
            input_columns,
            level_names=level_names,
            hide_levels=hide_levels,
            single_value=single_value,
            param_settings=param_settings,
            per_column=per_column,
            **build_col_kwargs,
        )
        rep_param_indexes = build_columns_meta["rep_param_indexes"]
        param_index = build_columns_meta["param_index"]
        final_index = build_columns_meta["final_index"]
    else:
        # Some indicators don't have any params
        rep_param_indexes = None
        param_index = None
        final_index = None

    # Prepare in-place outputs
    in_output_list_ready = []
    for i in range(len(in_output_list)):
        if input_shape_2d is None:
            raise ValueError("input_shape is required when using in-place outputs")
        if in_output_list[i] is not None:
            # This in-place output has been already broadcast with inputs
            in_output_wide = in_output_list[i]
            if isinstance(in_output_list[i], np.ndarray):
                in_output_wide = np.require(in_output_wide, requirements="W")
            if not per_column:
                # One per parameter combination
                in_output_wide = reshaping.tile(in_output_wide, n_unique_param_values, axis=1)
        else:
            # This in-place output hasn't been provided, so create empty
            _in_output_settings = resolve_dict(in_output_settings[i])
            dtype = _in_output_settings.get("dtype", None)
            if per_column:
                in_output_shape = input_shape_ready
            else:
                in_output_shape = (input_shape_2d[0], input_shape_2d[1] * n_unique_param_values)
            in_output_wide = np.empty(in_output_shape, dtype=dtype)
        in_output_list[i] = in_output_wide
        # Split each in-place output into chunks, each of input shape, and append to a list
        in_output_chunks = []
        if per_column:
            in_output_chunks.append(in_output_wide)
        else:
            for p in range(n_unique_param_values):
                if isinstance(in_output_wide, pd.DataFrame):
                    in_output = in_output_wide.iloc[:, p * input_shape_2d[1] : (p + 1) * input_shape_2d[1]]
                    if len(input_shape_ready) == 1:
                        in_output = in_output.iloc[:, 0]
                else:
                    in_output = in_output_wide[:, p * input_shape_2d[1] : (p + 1) * input_shape_2d[1]]
                    if len(input_shape_ready) == 1:
                        in_output = in_output[:, 0]
                if keep_pd and isinstance(in_output, np.ndarray):
                    in_output = ArrayWrapper(input_index, input_columns, in_output.ndim).wrap(in_output)
                in_output_chunks.append(in_output)
        in_output_list_ready.append(in_output_chunks)
    if checks.is_numba_func(custom_func):
        # Numba can't stand untyped lists
        in_output_list_ready = [to_typed_list(chunks) for chunks in in_output_list_ready]

    def _use_raw(_raw):
        # Use raw results of previous run to build outputs:
        # for each requested combination, pick the column chunk of its first match in the raw map
        _output_list, _param_map, _n_input_cols, _other_list = _raw
        idxs = np.array([_param_map.index(param_tuple) for param_tuple in zip(*param_list)])
        _output_list = [
            np.hstack([o[:, idx * _n_input_cols : (idx + 1) * _n_input_cols] for idx in idxs])
            for o in _output_list
        ]
        return _output_list, _param_map, _n_input_cols, _other_list

    # Get raw results
    if use_raw is not None:
        # Use raw results of previous run to build outputs
        output_list, param_map, n_input_cols, other_list = _use_raw(use_raw)
    else:
        # Prepare other arguments
        func_args = args
        func_kwargs = dict(kwargs)
        if pass_input_shape:
            func_kwargs["input_shape"] = input_shape_ready
        if pass_wrapper:
            func_kwargs["wrapper"] = wrapper_ready
        if pass_param_index:
            func_kwargs["param_index"] = param_index
        if pass_final_index:
            func_kwargs["final_index"] = final_index
        if pass_single_comb:
            func_kwargs["single_comb"] = single_comb
        if pass_per_column:
            func_kwargs["per_column"] = per_column

        # Substitute templates
        if has_templates(func_args) or has_templates(func_kwargs):
            template_context = merge_dicts(
                broadcast_named_args,
                dict(
                    input_shape=input_shape_ready,
                    wrapper=wrapper_ready,
                    **dict(zip(input_names, input_list_ready)),
                    **dict(zip(in_output_names, in_output_list_ready)),
                    **dict(zip(param_names, param_list_ready)),
                    pre_sub_args=func_args,
                    pre_sub_kwargs=func_kwargs,
                ),
                template_context,
            )
            func_args = substitute_templates(func_args, template_context, eval_id="custom_func_args")
            func_kwargs = substitute_templates(func_kwargs, template_context, eval_id="custom_func_kwargs")

        # Run the custom function
        if checks.is_numba_func(custom_func):
            # Keyword arguments are appended positionally for Numba-compiled functions
            func_args += tuple(func_kwargs.values())
            func_kwargs = {}
        if pass_packed:
            outputs = custom_func(
                tuple(input_list_ready),
                tuple(in_output_list_ready),
                tuple(param_list_ready),
                *func_args,
                **func_kwargs,
            )
        else:
            outputs = custom_func(
                *input_list_ready, *in_output_list_ready, *param_list_ready, *func_args, **func_kwargs
            )

        # Return outputs
        if isinstance(return_raw, str):
            if return_raw.lower() == "outputs":
                if use_run_unique and not silence_warnings:
                    warn("Raw outputs are produced by unique parameter combinations when run_unique=True")
                return outputs
            else:
                raise ValueError(f"Invalid return_raw: '{return_raw}'")

        # Return cache
        if kwargs.get("return_cache", False):
            if use_run_unique and not silence_warnings:
                warn("Cache is produced by unique parameter combinations when run_unique=True")
            return outputs

        # Post-process results
        if outputs is None:
            output_list = []
            other_list = []
        else:
            if isinstance(outputs, (tuple, list, List)):
                output_list = list(outputs)
            else:
                output_list = [outputs]
            # Other outputs should be returned without post-processing (for example cache_dict)
            if len(output_list) > num_ret_outputs:
                other_list = output_list[num_ret_outputs:]
                if use_run_unique and not silence_warnings:
                    warn(
                        "Additional output objects are produced by unique parameter combinations "
                        "when run_unique=True"
                    )
            else:
                other_list = []
            # Process only the num_ret_outputs outputs
            output_list = output_list[:num_ret_outputs]
        if len(output_list) != num_ret_outputs:
            raise ValueError("Number of returned outputs other than expected")
        output_list = [reshaping.to_2d_array(output) for output in output_list]

        # In-place outputs are treated as outputs from here
        output_list = in_output_list + output_list

        # Prepare raw
        param_map = list(zip(*param_list_unique))  # account for use_run_unique
        output_shape = output_list[0].shape
        for output in output_list:
            if output.shape != output_shape:
                raise ValueError("All outputs must have the same shape")
        if per_column:
            n_input_cols = output_shape[1]
        else:
            n_input_cols = output_shape[1] // n_unique_param_values
        if input_shape_2d is not None:
            if n_input_cols != input_shape_2d[1]:
                if per_column:
                    raise ValueError(
                        "All outputs must have the same number of columns as inputs when per_column=True"
                    )
                else:
                    raise ValueError(
                        "All outputs must have the same number of columns as there "
                        "are input columns times parameter combinations"
                    )
        raw = output_list, param_map, n_input_cols, other_list
        if return_raw:
            if use_run_unique and not silence_warnings:
                warn("Raw outputs are produced by unique parameter combinations when run_unique=True")
            return raw
        if use_run_unique:
            # Expand the unique results back to all requested combinations
            output_list, param_map, n_input_cols, other_list = _use_raw(raw)

    # Update shape and other meta if no inputs
    if input_shape is None:
        if n_input_cols == 1:
            input_shape = (output_list[0].shape[0],)
        else:
            input_shape = (output_list[0].shape[0], n_input_cols)
    if input_index is None:
        input_index = pd.RangeIndex(start=0, step=1, stop=input_shape[0])
    if input_columns is None:
        input_columns = pd.RangeIndex(start=0, step=1, stop=input_shape[1] if len(input_shape) > 1 else 1)

    # Build column hierarchy for indicator instance
    if len(param_list) > 0:
        if final_index is None:
            # Build new column levels on top of input levels
            build_columns_meta = build_columns(
                param_list,
                input_columns,
                level_names=level_names,
                hide_levels=hide_levels,
                single_value=single_value,
                param_settings=param_settings,
                per_column=per_column,
                **build_col_kwargs,
            )
            rep_param_indexes = build_columns_meta["rep_param_indexes"]
            final_index = build_columns_meta["final_index"]

        # Build a mapper that maps old columns in inputs to new columns
        # Instead of tiling all inputs to the shape of outputs and wasting memory,
        # we just keep a mapper and perform the tiling when needed
        input_mapper = None
        if len(input_list) > 0:
            if per_column:
                input_mapper = np.arange(len(input_columns))
            else:
                input_mapper = np.tile(np.arange(len(input_columns)), n_param_values)

        # Build mappers to easily map between parameters and columns
        mapper_list = [rep_param_indexes[i] for i in range(len(param_list))]
    else:
        # Some indicators don't have any params
        final_index = input_columns
        input_mapper = None
        mapper_list = []

    # Return artifacts: no pandas objects, just a wrapper and NumPy arrays
    new_ndim = len(input_shape) if output_list[0].shape[1] == 1 else output_list[0].ndim
    if new_ndim == 1 and not single_comb:
        new_ndim = 2
    wrapper = ArrayWrapper(input_index, final_index, new_ndim, **wrapper_kwargs)

    return (
        wrapper,
        input_list,
        input_mapper,
        output_list[: len(in_output_list)],
        output_list[len(in_output_list) :],
        param_list,
        mapper_list,
        other_list,
    )
@hybrid_method
def column_stack(
    cls_or_self: tp.MaybeType[IndicatorBaseT],
    *objs: tp.MaybeTuple[IndicatorBaseT],
    wrapper_kwargs: tp.KwargsLike = None,
    reindex_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> IndicatorBaseT:
    """Stack multiple `IndicatorBase` instances along columns x parameters.

    Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.

    All objects to be merged must have the same index.

    Any of the keys `wrapper`, `input_mapper`, `in_output_list`, `output_list`, `param_list`,
    and `mapper_list` that is already present in `kwargs` is taken as-is instead of being
    built from the objects. The input mapper is only stacked if every object has one.

    Raises:
        TypeError: If any object is not an instance of `IndicatorBase`.
    """
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        # Called on an instance: the instance itself is the first object to stack
        cls = type(cls_or_self)
        objs = (cls_or_self, *objs)
    if len(objs) == 1:
        # A single argument is expected to be a sequence of objects
        objs = objs[0]
    objs = list(objs)
    for candidate in objs:
        if not checks.is_instance_of(candidate, IndicatorBase):
            raise TypeError("Each object to be merged must be an instance of IndicatorBase")

    if "wrapper" not in kwargs:
        stack_wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs
        kwargs["wrapper"] = ArrayWrapper.column_stack(
            *(candidate.wrapper for candidate in objs),
            **stack_wrapper_kwargs,
        )

    if "input_mapper" not in kwargs:
        # Stops at the first object without an input mapper
        any_missing = any(getattr(candidate, "_input_mapper", None) is None for candidate in objs)
        if not any_missing:
            kwargs["input_mapper"] = np.concatenate([candidate._input_mapper for candidate in objs])

    if "in_output_list" not in kwargs:
        kwargs["in_output_list"] = [
            column_stack_arrays([getattr(candidate, f"_{name}") for candidate in objs])
            for name in cls.in_output_names
        ]

    if "output_list" not in kwargs:
        kwargs["output_list"] = [
            column_stack_arrays([getattr(candidate, f"_{name}") for candidate in objs])
            for name in cls.output_names
        ]

    if "param_list" not in kwargs:
        # Concatenate the parameter values of all objects, parameter by parameter
        kwargs["param_list"] = [
            [value for candidate in objs for value in getattr(candidate, f"_{name}_list")]
            for name in cls.param_names
        ]

    if "mapper_list" not in kwargs:
        stacked_mappers = []
        for name in cls.param_names:
            merged = None
            for candidate in objs:
                mapper = getattr(candidate, f"_{name}_mapper")
                # Append one mapper at a time, starting from the first object's mapper
                merged = mapper if merged is None else merged.append(mapper)
            stacked_mappers.append(merged)
        kwargs["mapper_list"] = stacked_mappers

    kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    return cls(**kwargs)
def __init__(
    self,
    wrapper: ArrayWrapper,
    input_list: InputListT,
    input_mapper: InputMapperT,
    in_output_list: InOutputListT,
    output_list: OutputListT,
    param_list: ParamListT,
    mapper_list: MapperListT,
    short_name: str,
    **kwargs,
) -> None:
    """Validate the arrays against `wrapper`, store them, and set up parameter indexers.

    Args:
        wrapper: Array wrapper; its `shape_2d` is the reference shape for all checks.
        input_list: Input arrays, one per name in `input_names`.
        input_mapper: Optional array with one entry per wrapper column
            (used by the input properties to select input columns).
        in_output_list: In-place output arrays, one per name in `in_output_names`.
        output_list: Output arrays, one per name in `output_names`.
        param_list: Parameter value lists, one per name in `param_names`.
        mapper_list: Parameter mappers, one per name in `param_names`,
            each with one entry per wrapper column.
        short_name: Short name of the indicator.
        **kwargs: Keyword arguments passed to `Analyzable.__init__`.
            The deprecated `level_names` key is discarded.
    """
    # Every array must agree with the 2-dim shape of the wrapper
    if input_mapper is not None:
        checks.assert_equal(input_mapper.shape[0], wrapper.shape_2d[1])
    for input_arr in input_list:
        checks.assert_equal(input_arr.shape[0], wrapper.shape_2d[0])
    for output_arr in in_output_list + output_list:
        checks.assert_equal(output_arr.shape, wrapper.shape_2d)
    for param_values in param_list:
        checks.assert_len_equal(param_list[0], param_values)
    for param_mapper in mapper_list:
        checks.assert_equal(len(param_mapper), wrapper.shape_2d[1])
    checks.assert_instance_of(short_name, str)
    kwargs.pop("level_names", None)  # deprecated

    Analyzable.__init__(
        self,
        wrapper,
        input_list=input_list,
        input_mapper=input_mapper,
        in_output_list=in_output_list,
        output_list=output_list,
        param_list=param_list,
        mapper_list=mapper_list,
        short_name=short_name,
        **kwargs,
    )

    # Store everything under private attributes read by the generated properties
    self._short_name = short_name
    for idx, input_name in enumerate(self.input_names):
        setattr(self, f"_{input_name}", input_list[idx])
    self._input_mapper = input_mapper
    for idx, in_output_name in enumerate(self.in_output_names):
        setattr(self, f"_{in_output_name}", in_output_list[idx])
    for idx, output_name in enumerate(self.output_names):
        setattr(self, f"_{output_name}", output_list[idx])
    for idx, param_name in enumerate(self.param_names):
        setattr(self, f"_{param_name}_list", param_list[idx])
        setattr(self, f"_{param_name}_mapper", mapper_list[idx])

    # Initialize indexers: one series per parameter plus one for the tuple of all parameters
    indexer_mappers = [pd.Series(m, index=wrapper.columns) for m in mapper_list]
    tuple_mapper = self._tuple_mapper
    if tuple_mapper is None:
        # NOTE(review): with fewer than two parameters the level names end up empty,
        # even if there is one parameter - confirm this is intended
        level_names = ()
    else:
        level_names = tuple(tuple_mapper.names)
        indexer_mappers.append(pd.Series(tuple_mapper.tolist(), index=wrapper.columns))
    self._level_names = level_names
    for base in type(self).__bases__:
        if base.__name__ != "ParamIndexer":
            continue
        base.__init__(self, indexer_mappers, level_names=[*level_names, level_names])


@property
def _tuple_mapper(self) -> tp.Optional[pd.MultiIndex]:
    """Mapper of multiple parameters.

    Returns None if there are fewer than two parameters."""
    names = self.param_names
    if len(names) > 1:
        return pd.MultiIndex.from_arrays([getattr(self, f"_{name}_mapper") for name in names])
    return None


@property
def _param_mapper(self) -> tp.Optional[pd.Index]:
    """Mapper of all parameters.

    Returns None without parameters, the parameter's own mapper with one parameter,
    and the tuple mapper otherwise."""
    names = self.param_names
    n_params = len(names)
    if n_params == 0:
        return None
    if n_params == 1:
        return getattr(self, f"_{names[0]}_mapper")
    return self._tuple_mapper


@property
def _visible_param_mapper(self) -> tp.Optional[pd.Index]:
    """Mapper of visible parameters.

    A parameter is visible if its mapper is named and that name is among
    the level names of the wrapper's columns. Returns None if nothing is visible."""
    names = self.param_names
    n_params = len(names)
    if n_params == 0:
        return None
    if n_params == 1:
        mapper = getattr(self, f"_{names[0]}_mapper")
        is_visible = mapper.name is not None and mapper.name in self.wrapper.columns.names
        return mapper if is_visible else None

    tuple_mapper = self._tuple_mapper
    visible = [
        tuple_mapper.get_level_values(level)
        for level, level_name in enumerate(tuple_mapper.names)
        if level_name is not None and level_name in self.wrapper.columns.names
    ]
    if len(visible) == 0:
        return None
    if len(visible) == 1:
        return visible[0]
    return pd.MultiIndex.from_arrays(visible)
def indexing_func(self: IndicatorBaseT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> IndicatorBaseT:
    """Perform indexing on `IndicatorBase`.

    Selects rows and columns from every input, in-place output, and output array,
    as well as columns from every parameter mapper, and returns a new instance
    built with `replace`.

    Args:
        *args: Passed to `wrapper.indexing_func_meta` if `wrapper_meta` is None.
        wrapper_meta: Precomputed indexing metadata with the keys `row_idxs`, `col_idxs`,
            `rows_changed`, `columns_changed`, and `new_wrapper`.
        **kwargs: Passed to `wrapper.indexing_func_meta` if `wrapper_meta` is None.
    """
    if wrapper_meta is None:
        wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
    row_idxs = wrapper_meta["row_idxs"]
    col_idxs = wrapper_meta["col_idxs"]
    rows_changed = wrapper_meta["rows_changed"]
    columns_changed = wrapper_meta["columns_changed"]
    if not isinstance(row_idxs, slice):
        row_idxs = reshaping.to_1d_array(row_idxs)
    if not isinstance(col_idxs, slice):
        col_idxs = reshaping.to_1d_array(col_idxs)

    def _select(attr_name, cols, cols_changed):
        # Pick the requested rows/columns from the array stored under the private attribute
        return ArrayWrapper.select_from_flex_array(
            getattr(self, f"_{attr_name}"),
            row_idxs=row_idxs,
            col_idxs=cols,
            rows_changed=rows_changed,
            columns_changed=cols_changed,
        )

    input_mapper = getattr(self, "_input_mapper", None)
    if input_mapper is not None and columns_changed:
        input_mapper = input_mapper[col_idxs]
    if input_mapper is None:
        input_cols, input_cols_changed = col_idxs, columns_changed
    else:
        # Columns are selected through the input mapper, so inputs keep all their columns
        input_cols, input_cols_changed = None, False

    input_list = [_select(name, input_cols, input_cols_changed) for name in self.input_names]
    in_output_list = [_select(name, col_idxs, columns_changed) for name in self.in_output_names]
    output_list = [_select(name, col_idxs, columns_changed) for name in self.output_names]
    param_list = [getattr(self, f"_{name}_list") for name in self.param_names]
    # Tuple mapper is a list because of its complex data type
    mapper_list = [getattr(self, f"_{name}_mapper")[col_idxs] for name in self.param_names]

    return self.replace(
        wrapper=wrapper_meta["new_wrapper"],
        input_list=input_list,
        input_mapper=input_mapper,
        in_output_list=in_output_list,
        output_list=output_list,
        param_list=param_list,
        mapper_list=mapper_list,
    )


@class_property
def short_name(cls_or_self) -> str:
    """Short name of the indicator."""
    return cls_or_self._short_name


@class_property
def input_names(cls_or_self) -> tp.Tuple[str, ...]:
    """Tuple with the names of the input arrays."""
    return cls_or_self._input_names


@class_property
def param_names(cls_or_self) -> tp.Tuple[str, ...]:
    """Tuple with the names of the parameters."""
    return cls_or_self._param_names


@class_property
def in_output_names(cls_or_self) -> tp.Tuple[str, ...]:
    """Tuple with the names of the in-place output arrays."""
    return cls_or_self._in_output_names


@class_property
def output_names(cls_or_self) -> tp.Tuple[str, ...]:
    """Tuple with the names of the regular output arrays."""
    return cls_or_self._output_names


@class_property
def lazy_output_names(cls_or_self) -> tp.Tuple[str, ...]:
    """Tuple with the names of the lazy output arrays."""
    return cls_or_self._lazy_output_names


@class_property
def output_flags(cls_or_self) -> tp.Kwargs:
    """Dictionary of output flags keyed by output name."""
    return cls_or_self._output_flags


@class_property
def param_defaults(cls_or_self) -> tp.Dict[str, tp.Any]:
    """Parameter defaults extracted from the signature of `IndicatorBase.run`.

    Values wrapped with `Default` are unwrapped."""
    names = cls_or_self.param_names
    return {
        name: (value.value if isinstance(value, Default) else value)
        for name, value in get_func_kwargs(cls_or_self.run).items()
        if name in names
    }


@property
def level_names(self) -> tp.Tuple[str]:
    """Column level names corresponding to each parameter."""
    return self._level_names


def unpack(self) -> tp.MaybeTuple[tp.SeriesFrame]:
    """Return the regular outputs: the output itself if there is only one, a tuple otherwise."""
    outputs = tuple(getattr(self, name) for name in self.output_names)
    return outputs[0] if len(outputs) == 1 else outputs


def to_dict(self, include_all: bool = True) -> tp.Dict[str, tp.SeriesFrame]:
    """Return outputs as a dict keyed by output name.

    If `include_all` is True, in-place and lazy outputs follow the regular outputs."""
    names = self.output_names
    if include_all:
        names = names + self.in_output_names + self.lazy_output_names
    return {name: getattr(self, name) for name in names}


def to_frame(self, include_all: bool = True) -> tp.Frame:
    """Return outputs as a DataFrame concatenated along columns.

    Output names become the outermost column level named "output"."""
    outputs = self.to_dict(include_all=include_all)
    keys = pd.Index(list(outputs.keys()), name="output")
    return pd.concat(list(outputs.values()), axis=1, keys=keys)
def get(self, key: tp.Optional[tp.Hashable] = None) -> tp.Optional[tp.SeriesFrame]:
    """Get a time series.

    Returns the main output if `key` is None, otherwise the attribute named `key`."""
    if key is None:
        return self.main_output
    return getattr(self, key)


def dropna(self: IndicatorBaseT, include_all: bool = True, **kwargs) -> IndicatorBaseT:
    """Drop missing values.

    Outputs are combined with `to_frame` (see `include_all` there) and rows are dropped
    based on that frame. Returns the instance itself if no row was dropped.

    Keyword arguments are passed to `pd.Series.dropna` or `pd.DataFrame.dropna`."""
    df = self.to_frame(include_all=include_all)
    new_df = df.dropna(**kwargs)
    if new_df.index.equals(df.index):
        return self
    return self.loc[new_df.index]


def rename(self: IndicatorBaseT, short_name: str) -> IndicatorBaseT:
    """Replace the short name of the indicator.

    Each parameter level name that starts with the current short name followed by
    an underscore gets this prefix replaced by `short_name`, both in the parameter
    mappers and in the columns of the wrapper."""
    # `level_names` is empty when there are fewer than two parameters; fall back to
    # the names of the mappers so that single-parameter indicators can be renamed too
    old_level_names = tuple(self.level_names)
    if len(old_level_names) != len(self.param_names):
        old_level_names = tuple(getattr(self, f"_{name}_mapper").name for name in self.param_names)

    prefix = self.short_name + "_"
    new_level_names = tuple(
        name.replace(self.short_name, short_name, 1) if isinstance(name, str) and name.startswith(prefix) else name
        for name in old_level_names
    )
    new_mapper_list = [
        getattr(self, f"_{param_name}_mapper").rename(new_name)
        for param_name, new_name in zip(self.param_names, new_level_names)
    ]

    # Rename by a full list of names: unlike a dict, this works for both a single-level
    # index and a multi-index
    new_columns = self.wrapper.columns
    old_column_names = list(new_columns.names)
    new_column_names = [
        new_level_names[old_level_names.index(name)] if name in old_level_names else name
        for name in old_column_names
    ]
    if new_column_names != old_column_names:
        new_columns = new_columns.set_names(new_column_names)
    new_wrapper = self.wrapper.replace(columns=new_columns)
    return self.replace(wrapper=new_wrapper, mapper_list=new_mapper_list, short_name=short_name)


def rename_levels(
    self: IndicatorBaseT,
    mapper: tp.MaybeMappingSequence[tp.Level],
    **kwargs,
) -> IndicatorBaseT:
    """Rename column levels with `Analyzable.rename_levels` and rename the parameter mappers accordingly.

    A parameter level that is present in the old columns takes the name of the level
    at the same position in the new columns."""
    new_self = Analyzable.rename_levels(self, mapper, **kwargs)
    old_column_names = self.wrapper.columns.names
    new_column_names = new_self.wrapper.columns.names

    # `level_names` is empty when there are fewer than two parameters; fall back to
    # the names of the mappers so that there is one level name per parameter
    level_names = tuple(new_self.level_names)
    if len(level_names) != len(new_self.param_names):
        level_names = tuple(getattr(new_self, f"_{name}_mapper").name for name in new_self.param_names)

    new_level_names = tuple(
        new_column_names[old_column_names.index(name)] if name in old_column_names else name
        for name in level_names
    )
    new_mapper_list = [
        getattr(new_self, f"_{param_name}_mapper").rename(new_name)
        for param_name, new_name in zip(new_self.param_names, new_level_names)
    ]
    return new_self.replace(mapper_list=new_mapper_list, level_names=new_level_names)
# ############# Iteration ############# #


def items(
    self,
    group_by: tp.GroupByLike = "params",
    apply_group_by: bool = False,
    keep_2d: bool = False,
    key_as_index: bool = False,
) -> tp.Items:
    """Iterate over columns (or groups if grouped and `Wrapping.group_select` is True).

    Allows the following additional options for `group_by`: "all_params", "params"
    (only those that aren't hidden), and parameter names. Names of column levels
    take precedence over these options. In a tuple or list, only parameter names
    are resolved."""

    def _resolve_param_name(name):
        # Swap a parameter name for its mapper unless a column level carries that name
        if name not in self.wrapper.columns.names and name in self.param_names:
            return getattr(self, f"_{name}_mapper")
        return name

    if isinstance(group_by, str):
        if group_by not in self.wrapper.columns.names:
            option = group_by.lower()
            if option == "all_params":
                group_by = self._param_mapper
            elif option == "params":
                group_by = self._visible_param_mapper
            elif group_by in self.param_names:
                group_by = getattr(self, f"_{group_by}_mapper")
    elif isinstance(group_by, (tuple, list)):
        group_by = type(group_by)([_resolve_param_name(g) if isinstance(g, str) else g for g in group_by])

    for key, obj in Analyzable.items(
        self,
        group_by=group_by,
        apply_group_by=apply_group_by,
        keep_2d=keep_2d,
        key_as_index=key_as_index,
    ):
        yield key, obj


# ############# Documentation ############# #


@classmethod
def fix_docstrings(cls, __pdoc__: dict) -> None:
    """Fix docstrings.

    For each known function attribute that the class has, register a default
    description in `__pdoc__` unless an entry is already there."""
    default_docs = (
        ("custom_func", "Custom function."),
        ("apply_func", "Apply function."),
        ("cache_func", "Cache function."),
        ("entry_place_func_nb", "Entry placement function."),
        ("exit_place_func_nb", "Exit placement function."),
    )
    for attr_name, doc in default_docs:
        if hasattr(cls, attr_name):
            __pdoc__.setdefault(cls.__name__ + "." + attr_name, doc)
Can be set to instance of `collections.namedtuple` acting as enumerated type, or any other mapping; It will then create a property with suffix `readable` that contains data in a string format. * `enum_unkval`: Value to be considered as unknown. Applies to enumerated data types only. * `make_cacheable`: Whether to make the property cacheable. Applies to inputs only. metrics (dict): Metrics supported by `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`. If dict, will be converted to `vectorbtpro.utils.config.Config`. stats_defaults (callable or dict): Defaults for `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`. If dict, will be converted into a property. subplots (dict): Subplots supported by `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`. If dict, will be converted to `vectorbtpro.utils.config.Config`. plots_defaults (callable or dict): Defaults for `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`. If dict, will be converted into a property. **kwargs: Custom keyword arguments passed to the config. !!! note The `__init__` method is not used for running the indicator, for this use `run`. The reason for this is indexing, which requires a clean `__init__` method for creating a new indicator object with newly indexed attributes. 
""" def __init__( self, class_name: tp.Optional[str] = None, class_docstring: tp.Optional[str] = None, module_name: tp.Optional[str] = __name__, short_name: tp.Optional[str] = None, prepend_name: bool = True, input_names: tp.Optional[tp.Sequence[str]] = None, param_names: tp.Optional[tp.Sequence[str]] = None, in_output_names: tp.Optional[tp.Sequence[str]] = None, output_names: tp.Optional[tp.Sequence[str]] = None, output_flags: tp.KwargsLike = None, lazy_outputs: tp.KwargsLike = None, attr_settings: tp.KwargsLike = None, metrics: tp.Optional[tp.Kwargs] = None, stats_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None, subplots: tp.Optional[tp.Kwargs] = None, plots_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None, **kwargs, ) -> None: Configured.__init__( self, class_name=class_name, class_docstring=class_docstring, module_name=module_name, short_name=short_name, prepend_name=prepend_name, input_names=input_names, param_names=param_names, in_output_names=in_output_names, output_names=output_names, output_flags=output_flags, lazy_outputs=lazy_outputs, attr_settings=attr_settings, metrics=metrics, stats_defaults=stats_defaults, subplots=subplots, plots_defaults=plots_defaults, **kwargs, ) # Check parameters if class_name is None: class_name = "Indicator" checks.assert_instance_of(class_name, str) if class_docstring is None: class_docstring = "" checks.assert_instance_of(class_docstring, str) if module_name is not None: checks.assert_instance_of(module_name, str) if short_name is None: if class_name == "Indicator": short_name = "custom" else: short_name = class_name.lower() checks.assert_instance_of(short_name, str) checks.assert_instance_of(prepend_name, bool) if input_names is None: input_names = [] else: checks.assert_sequence(input_names) input_names = list(input_names) if param_names is None: param_names = [] else: checks.assert_sequence(param_names) param_names = list(param_names) if in_output_names is None: in_output_names = [] else: 
def __init__(
    self,
    class_name: tp.Optional[str] = None,
    class_docstring: tp.Optional[str] = None,
    module_name: tp.Optional[str] = __name__,
    short_name: tp.Optional[str] = None,
    prepend_name: bool = True,
    input_names: tp.Optional[tp.Sequence[str]] = None,
    param_names: tp.Optional[tp.Sequence[str]] = None,
    in_output_names: tp.Optional[tp.Sequence[str]] = None,
    output_names: tp.Optional[tp.Sequence[str]] = None,
    output_flags: tp.KwargsLike = None,
    lazy_outputs: tp.KwargsLike = None,
    attr_settings: tp.KwargsLike = None,
    metrics: tp.Optional[tp.Kwargs] = None,
    stats_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None,
    subplots: tp.Optional[tp.Kwargs] = None,
    plots_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None,
    **kwargs,
) -> None:
    """Validate the arguments and build the skeleton of the indicator class.

    Creates a new class deriving from `IndicatorBase` and from a parameter indexer
    built with `build_param_indexer`, and attaches to it:

    * private name tuples (such as `_input_names`) read by the class properties,
    * a `{param}_list` property for each parameter,
    * a property for each input, in-place output, regular output, and lazy output,
    * combination methods and a `{attr}_stats` method for each attribute, selected by
      the `dtype` under `attr_settings` (mapping-like, number, or boolean),
    * metrics, subplots, and their defaults, if provided.

    The built class is available as `IndicatorFactory.Indicator`.

    Raises:
        ValueError: If there is neither an in-place nor a regular output, or if a name
            is used more than once across inputs, in-place outputs, and regular outputs.
    """
    Configured.__init__(
        self,
        class_name=class_name,
        class_docstring=class_docstring,
        module_name=module_name,
        short_name=short_name,
        prepend_name=prepend_name,
        input_names=input_names,
        param_names=param_names,
        in_output_names=in_output_names,
        output_names=output_names,
        output_flags=output_flags,
        lazy_outputs=lazy_outputs,
        attr_settings=attr_settings,
        metrics=metrics,
        stats_defaults=stats_defaults,
        subplots=subplots,
        plots_defaults=plots_defaults,
        **kwargs,
    )

    # Check parameters
    if class_name is None:
        class_name = "Indicator"
    checks.assert_instance_of(class_name, str)
    if class_docstring is None:
        class_docstring = ""
    checks.assert_instance_of(class_docstring, str)
    if module_name is not None:
        checks.assert_instance_of(module_name, str)
    if short_name is None:
        if class_name == "Indicator":
            short_name = "custom"
        else:
            short_name = class_name.lower()
    checks.assert_instance_of(short_name, str)
    checks.assert_instance_of(prepend_name, bool)
    if input_names is None:
        input_names = []
    else:
        checks.assert_sequence(input_names)
        input_names = list(input_names)
    if param_names is None:
        param_names = []
    else:
        checks.assert_sequence(param_names)
        param_names = list(param_names)
    if in_output_names is None:
        in_output_names = []
    else:
        checks.assert_sequence(in_output_names)
        in_output_names = list(in_output_names)
    if output_names is None:
        output_names = []
    else:
        checks.assert_sequence(output_names)
        output_names = list(output_names)
    all_output_names = in_output_names + output_names
    if len(all_output_names) == 0:
        raise ValueError("Must have at least one in-place or regular output")
    # Each of these names becomes a property and a private attribute of the indicator,
    # so a name shared by any two of them (not only by all three lists) would make
    # one attribute silently overwrite the other
    seen_names = set()
    duplicate_names = []
    for name in input_names + in_output_names + output_names:
        if name in seen_names and name not in duplicate_names:
            duplicate_names.append(name)
        seen_names.add(name)
    if len(duplicate_names) > 0:
        raise ValueError(
            f"Inputs, in-outputs, and outputs must all have unique names (duplicates: {duplicate_names})"
        )
    if output_flags is None:
        output_flags = {}
    checks.assert_instance_of(output_flags, dict)
    if len(output_flags) > 0:
        checks.assert_dict_valid(output_flags, all_output_names)
    if lazy_outputs is None:
        lazy_outputs = {}
    checks.assert_instance_of(lazy_outputs, dict)
    if attr_settings is None:
        attr_settings = {}
    checks.assert_instance_of(attr_settings, dict)
    all_attr_names = input_names + all_output_names + list(lazy_outputs.keys())
    if len(attr_settings) > 0:
        checks.assert_dict_valid(
            attr_settings,
            [
                all_attr_names,
                ["dtype", "enum_unkval", "make_cacheable"],
            ],
        )

    # Set up class
    ParamIndexer = build_param_indexer(
        param_names + (["tuple"] if len(param_names) > 1 else []),
        module_name=module_name,
    )
    Indicator = type(class_name, (IndicatorBase, ParamIndexer), {})
    Indicator.__doc__ = class_docstring
    if module_name is not None:
        Indicator.__module__ = module_name

    # Create read-only properties
    setattr(Indicator, "_short_name", short_name)
    setattr(Indicator, "_input_names", tuple(input_names))
    setattr(Indicator, "_param_names", tuple(param_names))
    setattr(Indicator, "_in_output_names", tuple(in_output_names))
    setattr(Indicator, "_output_names", tuple(output_names))
    setattr(Indicator, "_lazy_output_names", tuple(lazy_outputs.keys()))
    setattr(Indicator, "_output_flags", output_flags)

    # Loop variables are bound as default arguments so that each generated function
    # keeps its own name rather than the last value of the loop
    for param_name in param_names:

        def param_list_prop(self, _param_name=param_name) -> tp.List[tp.ParamValue]:
            return getattr(self, f"_{_param_name}_list")

        param_list_prop.__doc__ = f"List of `{param_name}` values."
        setattr(Indicator, f"{param_name}_list", property(param_list_prop))

    for input_name in input_names:
        _attr_settings = attr_settings.get(input_name, {})
        make_cacheable = _attr_settings.get("make_cacheable", False)

        def input_prop(self, _input_name: str = input_name) -> tp.SeriesFrame:
            """Input array."""
            old_input = reshaping.to_2d_array(getattr(self, "_" + _input_name))
            input_mapper = getattr(self, "_input_mapper")
            if input_mapper is None:
                return self.wrapper.wrap(old_input)
            return self.wrapper.wrap(old_input[:, input_mapper])

        input_prop.__name__ = input_name
        input_prop.__module__ = Indicator.__module__
        input_prop.__qualname__ = f"{Indicator.__name__}.{input_prop.__name__}"
        if make_cacheable:
            setattr(Indicator, input_name, cacheable_property(input_prop))
        else:
            setattr(Indicator, input_name, property(input_prop))

    for output_name in all_output_names:

        def output_prop(self, _output_name: str = output_name) -> tp.SeriesFrame:
            return self.wrapper.wrap(getattr(self, "_" + _output_name))

        if output_name in in_output_names:
            output_prop.__doc__ = """In-place output array."""
        else:
            output_prop.__doc__ = """Output array."""

        output_prop.__name__ = output_name
        output_prop.__module__ = Indicator.__module__
        output_prop.__qualname__ = f"{Indicator.__name__}.{output_prop.__name__}"
        if output_name in output_flags:
            # Flags are appended to the docstring of the output
            _output_flags = output_flags[output_name]
            if isinstance(_output_flags, (tuple, list)):
                _output_flags = ", ".join(_output_flags)
            output_prop.__doc__ += "\n\n" + _output_flags
        setattr(Indicator, output_name, property(output_prop))

    # Add user-defined outputs
    for prop_name, prop in lazy_outputs.items():
        prop.__name__ = prop_name
        prop.__module__ = Indicator.__module__
        prop.__qualname__ = f"{Indicator.__name__}.{prop.__name__}"
        if prop.__doc__ is None:
            prop.__doc__ = "Custom property."
        if not isinstance(prop, property):
            prop = property(prop)
        setattr(Indicator, prop_name, prop)

    # Add comparison & combination methods for all inputs, outputs, and user-defined properties
    def assign_combine_method(
        func_name: str,
        combine_func: tp.Callable,
        def_kwargs: tp.Kwargs,
        attr_name: str,
        docstring: str,
    ) -> None:
        def combine_method(
            self: IndicatorBaseT,
            other: tp.MaybeTupleList[tp.Union[IndicatorBaseT, tp.ArrayLike, BaseAccessor]],
            level_name: tp.Optional[str] = None,
            allow_multiple: bool = True,
            _prepend_name: bool = prepend_name,
            **kwargs,
        ) -> tp.SeriesFrame:
            # An indicator is replaced by its attribute of the same name,
            # a string by the attribute of this instance with that name
            if allow_multiple and isinstance(other, (tuple, list)):
                other = list(other)
                for i in range(len(other)):
                    if isinstance(other[i], IndicatorBase):
                        other[i] = getattr(other[i], attr_name)
                    elif isinstance(other[i], str):
                        other[i] = getattr(self, other[i])
            else:
                if isinstance(other, IndicatorBase):
                    other = getattr(other, attr_name)
                elif isinstance(other, str):
                    other = getattr(self, other)
            if level_name is None:
                if _prepend_name:
                    if attr_name == self.short_name:
                        level_name = f"{self.short_name}_{func_name}"
                    else:
                        level_name = f"{self.short_name}_{attr_name}_{func_name}"
                else:
                    level_name = f"{attr_name}_{func_name}"
            out = combine_objs(
                getattr(self, attr_name),
                other,
                combine_func,
                level_name=level_name,
                allow_multiple=allow_multiple,
                **merge_dicts(def_kwargs, kwargs),
            )
            return out

        combine_method.__name__ = f"{attr_name}_{func_name}"
        combine_method.__module__ = Indicator.__module__
        combine_method.__qualname__ = f"{Indicator.__name__}.{combine_method.__name__}"
        combine_method.__doc__ = docstring
        setattr(Indicator, f"{attr_name}_{func_name}", combine_method)

    for attr_name in all_attr_names:
        _attr_settings = attr_settings.get(attr_name, {})
        dtype = _attr_settings.get("dtype", float_)
        enum_unkval = _attr_settings.get("enum_unkval", -1)

        if checks.is_mapping_like(dtype):
            # Enumerated attribute: readable property and stats based on the mapping

            def attr_readable(
                self,
                _attr_name: str = attr_name,
                _mapping: tp.MappingLike = dtype,
                _enum_unkval: tp.Any = enum_unkval,
            ) -> tp.SeriesFrame:
                return getattr(self, _attr_name).vbt(mapping=_mapping).apply_mapping(enum_unkval=_enum_unkval)

            attr_readable.__name__ = f"{attr_name}_readable"
            attr_readable.__module__ = Indicator.__module__
            attr_readable.__qualname__ = f"{Indicator.__name__}.{attr_readable.__name__}"
            attr_readable.__doc__ = inspect.cleandoc(
                """`{attr_name}` in readable format based on the following mapping:

                ```python
                {dtype}
                ```"""
            ).format(attr_name=attr_name, dtype=prettify(to_value_mapping(dtype, enum_unkval=enum_unkval)))
            setattr(Indicator, f"{attr_name}_readable", property(attr_readable))

            def attr_stats(
                self,
                *args,
                _attr_name: str = attr_name,
                _mapping: tp.MappingLike = dtype,
                **kwargs,
            ) -> tp.SeriesFrame:
                return getattr(self, _attr_name).vbt(mapping=_mapping).stats(*args, **kwargs)

            attr_stats.__name__ = f"{attr_name}_stats"
            attr_stats.__module__ = Indicator.__module__
            attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
            attr_stats.__doc__ = inspect.cleandoc(
                """Stats of `{attr_name}` based on the following mapping:

                ```python
                {dtype}
                ```"""
            ).format(attr_name=attr_name, dtype=prettify(to_value_mapping(dtype)))
            setattr(Indicator, f"{attr_name}_stats", attr_stats)

        elif np.issubdtype(dtype, np.number):
            # Numeric attribute: comparison methods and generic stats
            func_info = [
                ("above", np.greater, dict()),
                ("below", np.less, dict()),
                ("equal", np.equal, dict()),
                (
                    "crossed_above",
                    lambda x, y, wait=0, dropna=False: jit_reg.resolve(generic_nb.crossed_above_nb)(
                        x,
                        y,
                        wait=wait,
                        dropna=dropna,
                    ),
                    dict(to_2d=True),
                ),
                (
                    # Crossing below is crossing above with swapped operands
                    "crossed_below",
                    lambda x, y, wait=0, dropna=False: jit_reg.resolve(generic_nb.crossed_above_nb)(
                        y,
                        x,
                        wait=wait,
                        dropna=dropna,
                    ),
                    dict(to_2d=True),
                ),
            ]
            for func_name, np_func, def_kwargs in func_info:
                method_docstring = f"""Return True for each element where `{attr_name}` is {func_name} `other`.

                See `vectorbtpro.indicators.factory.combine_objs`."""
                assign_combine_method(func_name, np_func, def_kwargs, attr_name, method_docstring)

            def attr_stats(self, *args, _attr_name: str = attr_name, **kwargs) -> tp.SeriesFrame:
                return getattr(self, _attr_name).vbt.stats(*args, **kwargs)

            attr_stats.__name__ = f"{attr_name}_stats"
            attr_stats.__module__ = Indicator.__module__
            attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
            attr_stats.__doc__ = f"""Stats of `{attr_name}` as generic."""
            setattr(Indicator, f"{attr_name}_stats", attr_stats)

        elif np.issubdtype(dtype, np.bool_):
            # Boolean attribute: logical methods and signal stats
            func_info = [
                ("and", np.logical_and, dict()),
                ("or", np.logical_or, dict()),
                ("xor", np.logical_xor, dict()),
            ]
            for func_name, np_func, def_kwargs in func_info:
                method_docstring = f"""Return `{attr_name} {func_name.upper()} other`.

                See `vectorbtpro.indicators.factory.combine_objs`."""
                assign_combine_method(func_name, np_func, def_kwargs, attr_name, method_docstring)

            def attr_stats(self, *args, _attr_name: str = attr_name, **kwargs) -> tp.SeriesFrame:
                return getattr(self, _attr_name).vbt.signals.stats(*args, **kwargs)

            attr_stats.__name__ = f"{attr_name}_stats"
            attr_stats.__module__ = Indicator.__module__
            attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
            attr_stats.__doc__ = f"""Stats of `{attr_name}` as signals."""
            setattr(Indicator, f"{attr_name}_stats", attr_stats)

    # Prepare stats
    if metrics is not None:
        if not isinstance(metrics, Config):
            metrics = Config(metrics, options_=dict(copy_kwargs=dict(copy_mode="deep")))
        setattr(Indicator, "_metrics", metrics.copy())

    if stats_defaults is not None:
        if isinstance(stats_defaults, dict):

            def stats_defaults_prop(self, _stats_defaults: tp.Kwargs = stats_defaults) -> tp.Kwargs:
                return _stats_defaults

        else:

            def stats_defaults_prop(self, _stats_defaults: tp.Kwargs = stats_defaults) -> tp.Kwargs:
                return stats_defaults(self)

        stats_defaults_prop.__name__ = "stats_defaults"
        stats_defaults_prop.__module__ = Indicator.__module__
        stats_defaults_prop.__qualname__ = f"{Indicator.__name__}.{stats_defaults_prop.__name__}"
        setattr(Indicator, "stats_defaults", property(stats_defaults_prop))

    # Prepare plots
    if subplots is not None:
        if not isinstance(subplots, Config):
            subplots = Config(subplots, options_=dict(copy_kwargs=dict(copy_mode="deep")))
        setattr(Indicator, "_subplots", subplots.copy())

    if plots_defaults is not None:
        if isinstance(plots_defaults, dict):

            def plots_defaults_prop(self, _plots_defaults: tp.Kwargs = plots_defaults) -> tp.Kwargs:
                return _plots_defaults

        else:

            def plots_defaults_prop(self, _plots_defaults: tp.Kwargs = plots_defaults) -> tp.Kwargs:
                return plots_defaults(self)

        plots_defaults_prop.__name__ = "plots_defaults"
        plots_defaults_prop.__module__ = Indicator.__module__
        plots_defaults_prop.__qualname__ = f"{Indicator.__name__}.{plots_defaults_prop.__name__}"
        setattr(Indicator, "plots_defaults", property(plots_defaults_prop))

    # Store arguments
    self._class_name = class_name
    self._class_docstring = class_docstring
    self._module_name = module_name
    self._short_name = short_name
    self._prepend_name = prepend_name
    self._input_names = input_names
    self._param_names = param_names
    self._in_output_names = in_output_names
    self._output_names = output_names
    self._output_flags = output_flags
    self._lazy_outputs = lazy_outputs
    self._attr_settings = attr_settings
    self._metrics = metrics
    self._stats_defaults = stats_defaults
    self._subplots = subplots
    self._plots_defaults = plots_defaults

    # Store indicator class
    self._Indicator = Indicator


@property
def class_name(self) -> str:
    """Name for the created indicator class."""
    return self._class_name


@property
def class_docstring(self) -> str:
    """Docstring for the created indicator class."""
    return self._class_docstring


@property
def module_name(self) -> str:
    """Name of the module the class originates from."""
    return self._module_name


@property
def short_name(self) -> str:
    """Short name of the indicator."""
    return self._short_name
@property
def prepend_name(self) -> bool:
    """Whether `IndicatorFactory.short_name` is prepended to each parameter level."""
    return self._prepend_name


@property
def input_names(self) -> tp.List[str]:
    """Names of the inputs, as a list."""
    return self._input_names


@property
def param_names(self) -> tp.List[str]:
    """Names of the parameters, as a list."""
    return self._param_names


@property
def in_output_names(self) -> tp.List[str]:
    """Names of the in-place outputs, as a list."""
    return self._in_output_names


@property
def output_names(self) -> tp.List[str]:
    """Names of the regular outputs, as a list."""
    return self._output_names


@property
def output_flags(self) -> tp.Kwargs:
    """Flags of in-place and regular outputs, as a dictionary."""
    return self._output_flags


@property
def lazy_outputs(self) -> tp.Kwargs:
    """User-defined functions that become properties of the indicator class, as a dictionary."""
    return self._lazy_outputs


@property
def attr_settings(self) -> tp.Kwargs:
    """Attribute settings, as a dictionary."""
    return self._attr_settings


@property
def metrics(self) -> Config:
    """Metrics supported by `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`."""
    return self._metrics


@property
def stats_defaults(self) -> tp.Kwargs:
    """Defaults for `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`."""
    return self._stats_defaults


@property
def subplots(self) -> Config:
    """Subplots supported by `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`."""
    return self._subplots


@property
def plots_defaults(self) -> tp.Kwargs:
    """Defaults for `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`."""
    return self._plots_defaults


@property
def Indicator(self) -> tp.Type[IndicatorBase]:
    """Indicator class built by the factory."""
    return self._Indicator
var_args: bool = False, keyword_only_args: bool = False, **pipeline_kwargs, ) -> tp.Type[IndicatorBase]: """Build indicator class around a custom calculation function. In contrast to `IndicatorFactory.with_apply_func`, this method offers full flexibility. It's up to the user to handle caching and concatenate columns for each parameter (for example, by using `vectorbtpro.base.combining.apply_and_concat`). Also, you must ensure that each output array has an appropriate number of columns, which is the number of columns in input arrays multiplied by the number of parameter combinations. Args: custom_func (callable): A function that takes broadcast arrays corresponding to `input_names`, broadcast in-place output arrays corresponding to `in_output_names`, broadcast parameter arrays corresponding to `param_names`, and other arguments and keyword arguments, and returns outputs corresponding to `output_names` and other objects that are then returned with the indicator instance. Can be Numba-compiled. !!! note Shape of each output must be the same and match the shape of each input stacked n times (= the number of parameter values) along the column axis. require_input_shape (bool): Whether the input shape is required. param_settings (dict): A dictionary of parameter settings keyed by name. See `IndicatorBase.run_pipeline` for keys. Can be overwritten by any run method. in_output_settings (dict): A dictionary of in-place output settings keyed by name. See `IndicatorBase.run_pipeline` for keys. Can be overwritten by any run method. hide_params (bool or list of str): Parameter names to hide column levels for, or whether to hide all parameters. Can be overwritten by any run method. hide_default (bool): Whether to hide column levels of parameters with default value. Can be overwritten by any run method. var_args (bool): Whether run methods should accept variable arguments (`*args`). Set to True if `custom_func` accepts positional arguments that are not listed in the config. 
keyword_only_args (bool): Whether run methods should accept keyword-only arguments (`*`). Set to True to force the user to use keyword arguments (e.g., to avoid misplacing arguments). **pipeline_kwargs: Keyword arguments passed to `IndicatorBase.run_pipeline`. Can be overwritten by any run method. Can contain default values and also references to other arguments wrapped with `vectorbtpro.base.reshaping.Ref`. Returns: `Indicator`, and optionally other objects that are returned by `custom_func` and exceed `output_names`. Usage: * The following example produces the same indicator as the `IndicatorFactory.with_apply_func` example. ```pycon >>> @njit >>> def apply_func_nb(i, ts1, ts2, p1, p2, arg1, arg2): ... return ts1 * p1[i] + arg1, ts2 * p2[i] + arg2 >>> @njit ... def custom_func(ts1, ts2, p1, p2, arg1, arg2): ... return vbt.base.combining.apply_and_concat_multiple_nb( ... len(p1), apply_func_nb, ts1, ts2, p1, p2, arg1, arg2) >>> MyInd = vbt.IF( ... input_names=['ts1', 'ts2'], ... param_names=['p1', 'p2'], ... output_names=['o1', 'o2'] ... ).with_custom_func(custom_func, var_args=True, arg2=200) >>> myInd = MyInd.run(price, price * 2, [1, 2], [3, 4], 100) >>> myInd.o1 custom_p1 1 2 custom_p2 3 4 a b a b 2020-01-01 101.0 105.0 102.0 110.0 2020-01-02 102.0 104.0 104.0 108.0 2020-01-03 103.0 103.0 106.0 106.0 2020-01-04 104.0 102.0 108.0 104.0 2020-01-05 105.0 101.0 110.0 102.0 >>> myInd.o2 custom_p1 1 2 custom_p2 3 4 a b a b 2020-01-01 206.0 230.0 208.0 240.0 2020-01-02 212.0 224.0 216.0 232.0 2020-01-03 218.0 218.0 224.0 224.0 2020-01-04 224.0 212.0 232.0 216.0 2020-01-05 230.0 206.0 240.0 208.0 ``` The difference between `apply_func_nb` here and in `IndicatorFactory.with_apply_func` is that here it takes the index of the current parameter combination that can be used for parameter selection. * You can also remove the entire `apply_func_nb` and define your logic in `custom_func` (which shouldn't necessarily be Numba-compiled): ```pycon >>> @njit ... 
def custom_func(ts1, ts2, p1, p2, arg1, arg2): ... input_shape = ts1.shape ... n_params = len(p1) ... out1 = np.empty((input_shape[0], input_shape[1] * n_params), dtype=float_) ... out2 = np.empty((input_shape[0], input_shape[1] * n_params), dtype=float_) ... for k in range(n_params): ... for col in range(input_shape[1]): ... for i in range(input_shape[0]): ... out1[i, input_shape[1] * k + col] = ts1[i, col] * p1[k] + arg1 ... out2[i, input_shape[1] * k + col] = ts2[i, col] * p2[k] + arg2 ... return out1, out2 ``` """ Indicator = self.Indicator short_name = self.short_name prepend_name = self.prepend_name input_names = self.input_names param_names = self.param_names in_output_names = self.in_output_names output_names = self.output_names all_input_names = input_names + param_names + in_output_names setattr(Indicator, "custom_func", custom_func) def _split_args( args: tp.Sequence, ) -> tp.Tuple[tp.Dict[str, tp.ArrayLike], tp.Dict[str, tp.ArrayLike], tp.Dict[str, tp.ParamValues], tp.Args]: inputs = dict(zip(input_names, args[: len(input_names)])) checks.assert_len_equal(inputs, input_names) args = args[len(input_names) :] params = dict(zip(param_names, args[: len(param_names)])) checks.assert_len_equal(params, param_names) args = args[len(param_names) :] in_outputs = dict(zip(in_output_names, args[: len(in_output_names)])) checks.assert_len_equal(in_outputs, in_output_names) args = args[len(in_output_names) :] if not var_args and len(args) > 0: raise TypeError( "Variable length arguments are not supported by this function (var_args is set to False)" ) return inputs, in_outputs, params, args for k, v in pipeline_kwargs.items(): if k in param_names and not isinstance(v, Default): pipeline_kwargs[k] = Default(v) # track default params pipeline_kwargs = merge_dicts({k: None for k in in_output_names}, pipeline_kwargs) # Display default parameters and in-place outputs in the signature default_kwargs = {} for k in list(pipeline_kwargs.keys()): if k in input_names or k in 
param_names or k in in_output_names: default_kwargs[k] = pipeline_kwargs.pop(k) if var_args and keyword_only_args: raise ValueError("var_args and keyword_only_args cannot be used together") # Add private run method def_run_kwargs = dict( short_name=short_name, hide_params=hide_params, hide_default=hide_default, **default_kwargs, ) def _run(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunOutputT: _short_name = kwargs.pop("short_name", def_run_kwargs["short_name"]) _hide_params = kwargs.pop("hide_params", def_run_kwargs["hide_params"]) _hide_default = kwargs.pop("hide_default", def_run_kwargs["hide_default"]) _param_settings = merge_dicts(param_settings, kwargs.pop("param_settings", {})) _in_output_settings = merge_dicts(in_output_settings, kwargs.pop("in_output_settings", {})) if isinstance(_hide_params, bool): if not _hide_params: _hide_params = None else: _hide_params = param_names if _hide_params is None: _hide_params = [] args = list(args) # Split arguments inputs, in_outputs, params, args = _split_args(args) # Prepare column levels level_names = [] hide_levels = [] for pname in param_names: level_name = _short_name + "_" + pname if prepend_name else pname level_names.append(level_name) if pname in _hide_params or (_hide_default and isinstance(params[pname], Default)): hide_levels.append(level_name) for k, v in params.items(): if isinstance(v, Default): params[k] = v.value # Run the pipeline results = Indicator.run_pipeline( len(output_names), # number of returned outputs custom_func, *args, require_input_shape=require_input_shape, inputs=inputs, in_outputs=in_outputs, params=params, level_names=level_names, hide_levels=hide_levels, param_settings=_param_settings, in_output_settings=_in_output_settings, **merge_dicts(pipeline_kwargs, kwargs), ) # Return the raw result if any of the flags are set if kwargs.get("return_raw", False) or kwargs.get("return_cache", False): return results # Unpack the result ( wrapper, new_input_list, input_mapper, in_output_list, 
output_list, new_param_list, mapper_list, other_list, ) = results # Create a new instance obj = cls( wrapper, new_input_list, input_mapper, in_output_list, output_list, new_param_list, mapper_list, short_name, ) if len(other_list) > 0: return (obj, *tuple(other_list)) return obj setattr(Indicator, "_run", classmethod(_run)) # Add public run method # Create function dynamically to provide user with a proper signature def compile_run_function(func_name: str, docstring: str, _default_kwargs: tp.KwargsLike = None) -> tp.Callable: pos_names = [] main_kw_names = [] other_kw_names = [] if _default_kwargs is None: _default_kwargs = {} for k in input_names + param_names: if k in _default_kwargs: main_kw_names.append(k) else: pos_names.append(k) main_kw_names.extend(in_output_names) # in_output_names are keyword-only for k, v in _default_kwargs.items(): if k not in pos_names and k not in main_kw_names: other_kw_names.append(k) _0 = func_name _1 = "*, " if keyword_only_args else "" _2 = [] if require_input_shape: _2.append("input_shape") _2.extend(pos_names) _2 = ", ".join(_2) + ", " if len(_2) > 0 else "" _3 = "*args, " if var_args else "" _4 = ["{}={}".format(k, k) for k in main_kw_names + other_kw_names] if require_input_shape: _4 += ["input_index=None", "input_columns=None"] _4 = ", ".join(_4) + ", " if len(_4) > 0 else "" _5 = docstring _6 = all_input_names _6 = ", ".join(_6) + ", " if len(_6) > 0 else "" _7 = [] if require_input_shape: _7.append("input_shape") _7.extend(other_kw_names) _7 = ["{}={}".format(k, k) for k in _7] if require_input_shape: _7 += ["input_index=input_index", "input_columns=input_columns"] _7 = ", ".join(_7) + ", " if len(_7) > 0 else "" func_str = ( "@classmethod\n" "def {0}(cls, {1}{2}{3}{4}**kwargs):\n" ' """{5}"""\n' " return cls._{0}({6}{3}{7}**kwargs)".format(_0, _1, _2, _3, _4, _5, _6, _7) ) scope = {**dict(Default=Default), **_default_kwargs} filename = inspect.getfile(lambda: None) code = compile(func_str, filename, "single") exec(code, 
scope) return scope[func_name] _0 = self.class_name _1 = "" if len(self.input_names) > 0: _1 += "\n* Inputs: " + ", ".join(map(lambda x: f"`{x}`", self.input_names)) if len(self.in_output_names) > 0: _1 += "\n* In-place outputs: " + ", ".join(map(lambda x: f"`{x}`", self.in_output_names)) if len(self.param_names) > 0: _1 += "\n* Parameters: " + ", ".join(map(lambda x: f"`{x}`", self.param_names)) if len(self.output_names) > 0: _1 += "\n* Outputs: " + ", ".join(map(lambda x: f"`{x}`", self.output_names)) if len(self.lazy_outputs) > 0: _1 += "\n* Lazy outputs: " + ", ".join(map(lambda x: f"`{x}`", list(self.lazy_outputs.keys()))) run_docstring = """Run `{0}` indicator. {1} Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all. Set `hide_default` to False to show the column levels of the parameters with a default value. Other keyword arguments are passed to `{0}.run_pipeline`.""".format( _0, _1, ) run = compile_run_function("run", run_docstring, def_run_kwargs) run.__name__ = "run" run.__module__ = Indicator.__module__ run.__qualname__ = f"{Indicator.__name__}.{run.__name__}" setattr(Indicator, "run", run) if len(param_names) > 0: # Add private run_combs method def_run_combs_kwargs = dict( r=2, param_product=False, comb_func=itertools.combinations, run_unique=True, short_names=None, hide_params=hide_params, hide_default=hide_default, **default_kwargs, ) def _run_combs(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunCombsOutputT: _r = kwargs.pop("r", def_run_combs_kwargs["r"]) _param_product = kwargs.pop("param_product", def_run_combs_kwargs["param_product"]) _comb_func = kwargs.pop("comb_func", def_run_combs_kwargs["comb_func"]) _run_unique = kwargs.pop("run_unique", def_run_combs_kwargs["run_unique"]) _short_names = kwargs.pop("short_names", def_run_combs_kwargs["short_names"]) _hide_params = kwargs.pop("hide_params", def_run_kwargs["hide_params"]) _hide_default = kwargs.pop("hide_default", 
def_run_kwargs["hide_default"]) _param_settings = merge_dicts(param_settings, kwargs.get("param_settings", {})) if isinstance(_hide_params, bool): if not _hide_params: _hide_params = None else: _hide_params = param_names if _hide_params is None: _hide_params = [] if _short_names is None: _short_names = [f"{short_name}_{str(i + 1)}" for i in range(_r)] args = list(args) # Split arguments inputs, in_outputs, params, args = _split_args(args) # Hide params for pname in param_names: if _hide_default and isinstance(params[pname], Default): params[pname] = params[pname].value if pname not in _hide_params: _hide_params.append(pname) checks.assert_len_equal(params, param_names) # Bring argument to list format input_list = list(inputs.values()) in_output_list = list(in_outputs.values()) param_list = list(params.values()) # Prepare params for i, pname in enumerate(param_names): is_tuple = _param_settings.get(pname, {}).get("is_tuple", False) is_array_like = _param_settings.get(pname, {}).get("is_array_like", False) param_list[i] = params_to_list(params[pname], is_tuple, is_array_like) if _param_product: param_list = create_param_product(param_list) else: param_list = broadcast_params(param_list) # Speed up by pre-calculating raw outputs if _run_unique: raw_results = cls._run( *input_list, *param_list, *in_output_list, *args, return_raw=True, run_unique=False, **kwargs, ) kwargs["use_raw"] = raw_results # use them next time # Generate indicator instances instances = [] if _comb_func == itertools.product: param_lists = zip(*_comb_func(zip(*param_list), repeat=_r)) else: param_lists = zip(*_comb_func(zip(*param_list), _r)) for i, param_list in enumerate(param_lists): instances.append( cls._run( *input_list, *zip(*param_list), *in_output_list, *args, short_name=_short_names[i], hide_params=_hide_params, hide_default=_hide_default, run_unique=False, **kwargs, ) ) return tuple(instances) setattr(Indicator, "_run_combs", classmethod(_run_combs)) # Add public run_combs method _0 = 
self.class_name _1 = "" if len(self.input_names) > 0: _1 += "\n* Inputs: " + ", ".join(map(lambda x: f"`{x}`", self.input_names)) if len(self.in_output_names) > 0: _1 += "\n* In-place outputs: " + ", ".join(map(lambda x: f"`{x}`", self.in_output_names)) if len(self.param_names) > 0: _1 += "\n* Parameters: " + ", ".join(map(lambda x: f"`{x}`", self.param_names)) if len(self.output_names) > 0: _1 += "\n* Outputs: " + ", ".join(map(lambda x: f"`{x}`", self.output_names)) if len(self.lazy_outputs) > 0: _1 += "\n* Lazy outputs: " + ", ".join(map(lambda x: f"`{x}`", list(self.lazy_outputs.keys()))) run_combs_docstring = """Create a combination of multiple `{0}` indicators using function `comb_func`. {1} `comb_func` must accept an iterable of parameter tuples and `r`. Also accepts all combinatoric iterators from itertools such as `itertools.combinations`. Pass `r` to specify how many indicators to run. Pass `short_names` to specify the short name for each indicator. Set `run_unique` to True to first compute raw outputs for all parameters, and then use them to build each indicator (faster). Other keyword arguments are passed to `{0}.run`. !!! note This method should only be used when multiple indicators are needed. To test multiple parameters, pass them as lists to `{0}.run`. 
""".format( _0, _1, ) run_combs = compile_run_function("run_combs", run_combs_docstring, def_run_combs_kwargs) run_combs.__name__ = "run_combs" run_combs.__module__ = Indicator.__module__ run_combs.__qualname__ = f"{Indicator.__name__}.{run_combs.__name__}" setattr(Indicator, "run_combs", run_combs) return Indicator def with_apply_func( self, apply_func: tp.Callable, cache_func: tp.Optional[tp.Callable] = None, takes_1d: bool = False, select_params: bool = True, pass_packed: bool = False, cache_pass_packed: tp.Optional[bool] = None, pass_per_column: bool = False, cache_pass_per_column: tp.Optional[bool] = None, forward_skipna: bool = False, kwargs_as_args: tp.Optional[tp.Iterable[str]] = None, jit_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.Union[str, tp.Type[IndicatorBase]]: """Build indicator class around a custom apply function. In contrast to `IndicatorFactory.with_custom_func`, this method handles a lot of things for you, such as caching, parameter selection, and concatenation. Your part is writing a function `apply_func` that accepts a selection of parameters (single values as opposed to multiple values in `IndicatorFactory.with_custom_func`) and does the calculation. It then automatically concatenates the resulting arrays into a single array per output. While this approach is simpler, it's also less flexible, since we can only work with one parameter selection at a time and can't view all parameters. The execution and concatenation is performed using `vectorbtpro.base.combining.apply_and_concat_each`. !!! note If `apply_func` is a Numba-compiled function: * All inputs are automatically converted to NumPy arrays * Each argument in `*args` must be of a Numba-compatible type * You cannot pass keyword arguments * Your outputs must be arrays of the same shape, data type and data order !!! note Reserved arguments such as `per_column` (in this order) get passed as positional arguments if `jitted_loop` is True, otherwise as keyword arguments. 
Args: apply_func (callable): A function that takes inputs, selection of parameters, and other arguments, and does calculations to produce outputs. Arguments are passed to `apply_func` in the following order: * `i` (index of the parameter combination) if `select_params` is set to False * `input_shape` if `pass_input_shape` is set to True and `input_shape` not in `kwargs_as_args` * Input arrays corresponding to `input_names`. Passed as a tuple if `pass_packed`, otherwise unpacked. If `select_params` is True, each argument is a list composed of multiple arrays - one per parameter combination. When `per_column` is True, each of those arrays corresponds to a column. Otherwise, they all refer to the same array. If `takes_1d`, each array gets additionally split into multiple column arrays. Still passed as a single array to the caching function. * In-output arrays corresponding to `in_output_names`. Passed as a tuple if `pass_packed`, otherwise unpacked. If `select_params` is True, each argument is a list composed of multiple arrays - one per parameter combination. When `per_column` is True, each of those arrays corresponds to a column. If `takes_1d`, each array gets additionally split into multiple column arrays. Still passed as a single array to the caching function. * Parameters corresponding to `param_names`. Passed as a tuple if `pass_packed`, otherwise unpacked. If `select_params` is True, each argument is a list composed of multiple values - one per parameter combination. When `per_column` is True, each of those values corresponds to a column. If `takes_1d`, each value gets additionally repeated by the number of columns in the input arrays. * Variable arguments if `var_args` is set to True * `per_column` if `pass_per_column` is set to True and `per_column` not in `kwargs_as_args` and `jitted_loop` is set to True * Arguments listed in `kwargs_as_args` passed as positional. Can include `takes_1d` and `per_column`. * Other keyword arguments if `jitted_loop` is False. 
Also includes `takes_1d` and `per_column` if they must be passed and not in `kwargs_as_args`. Can be Numba-compiled (but doesn't have to). !!! note Shape of each output must be the same and match the shape of each input. cache_func (callable): A caching function to preprocess data beforehand. Takes the same arguments as `apply_func`. Must return a single object or a tuple of objects. All returned objects will be passed unpacked as last arguments to `apply_func`. Can be Numba-compiled (but doesn't have to). takes_1d (bool): Whether to split 2-dim arrays into multiple 1-dim arrays along the column axis. Gets applied on inputs and in-outputs, while parameters get repeated by the number of columns. select_params (bool): Whether to automatically select in-outputs and parameters. If False, prepends the current iteration index to the arguments. pass_packed (bool): Whether to pass packed tuples for inputs, in-place outputs, and parameters. cache_pass_packed (bool): Overrides `pass_packed` for the caching function. pass_per_column (bool): Whether to pass `per_column`. cache_pass_per_column (bool): Overrides `pass_per_column` for the caching function. forward_skipna (bool): Whether to forward `skipna` to the apply function. kwargs_as_args (iterable of str): Keyword arguments from `kwargs` dict to pass as positional arguments to the apply function. Should be used together with `jitted_loop` set to True since Numba doesn't support variable keyword arguments. Defaults to []. Order matters. jit_kwargs (dict): Keyword arguments passed to `@njit` decorator of the parameter selection function. By default, has `nogil` set to True. **kwargs: Keyword arguments passed to `IndicatorFactory.with_custom_func`, all the way down to `vectorbtpro.base.combining.apply_and_concat_each`. Returns: Indicator Usage: * The following example produces the same indicator as the `IndicatorFactory.with_custom_func` example. ```pycon >>> @njit ... def apply_func_nb(ts1, ts2, p1, p2, arg1, arg2): ... 
return ts1 * p1 + arg1, ts2 * p2 + arg2 >>> MyInd = vbt.IF( ... input_names=['ts1', 'ts2'], ... param_names=['p1', 'p2'], ... output_names=['out1', 'out2'] ... ).with_apply_func( ... apply_func_nb, var_args=True, ... kwargs_as_args=['arg2'], arg2=200) >>> myInd = MyInd.run(price, price * 2, [1, 2], [3, 4], 100) >>> myInd.out1 custom_p1 1 2 custom_p2 3 4 a b a b 2020-01-01 101.0 105.0 102.0 110.0 2020-01-02 102.0 104.0 104.0 108.0 2020-01-03 103.0 103.0 106.0 106.0 2020-01-04 104.0 102.0 108.0 104.0 2020-01-05 105.0 101.0 110.0 102.0 >>> myInd.out2 custom_p1 1 2 custom_p2 3 4 a b a b 2020-01-01 206.0 230.0 208.0 240.0 2020-01-02 212.0 224.0 216.0 232.0 2020-01-03 218.0 218.0 224.0 224.0 2020-01-04 224.0 212.0 232.0 216.0 2020-01-05 230.0 206.0 240.0 208.0 ``` * To change the execution engine or specify other engine-related arguments, use `execute_kwargs`: ```pycon >>> import time >>> def apply_func(ts, p): ... time.sleep(1) ... return ts * p >>> MyInd = vbt.IF( ... input_names=['ts'], ... param_names=['p'], ... output_names=['out'] ... ).with_apply_func(apply_func) >>> %timeit MyInd.run(price, [1, 2, 3]) 3.02 s ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) >>> %timeit MyInd.run(price, [1, 2, 3], execute_kwargs=dict(engine='dask')) 1.02 s ± 2.67 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each) ``` """ Indicator = self.Indicator setattr(Indicator, "apply_func", apply_func) setattr(Indicator, "cache_func", cache_func) module_name = self.module_name input_names = self.input_names output_names = self.output_names in_output_names = self.in_output_names param_names = self.param_names num_ret_outputs = len(output_names) if kwargs_as_args is None: kwargs_as_args = [] if checks.is_numba_func(apply_func): # Build a function that selects a parameter tuple # Do it here to avoid compilation with Numba every time custom_func is run _0 = "i" _0 += ", args_before" if len(input_names) > 0: _0 += ", " + ", ".join(input_names) if len(in_output_names) > 0: _0 += ", " + ", ".join(in_output_names) if len(param_names) > 0: _0 += ", " + ", ".join(param_names) _0 += ", *args" if select_params: _1 = "*args_before" else: _1 = "i, *args_before" if pass_packed: if len(input_names) > 0: _1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), input_names)) + ",)" else: _1 += ", ()" if len(in_output_names) > 0: _1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), in_output_names)) + ",)" else: _1 += ", ()" if len(param_names) > 0: _1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), param_names)) + ",)" else: _1 += ", ()" else: if len(input_names) > 0: _1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), input_names)) if len(in_output_names) > 0: _1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), in_output_names)) if len(param_names) > 0: _1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), param_names)) _1 += ", *args" func_str = "def param_select_func_nb({0}):\n return apply_func({1})".format(_0, _1) scope = {"apply_func": apply_func} filename = inspect.getfile(lambda: None) code = compile(func_str, filename, "single") exec(code, scope) param_select_func_nb = scope["param_select_func_nb"] param_select_func_nb.__doc__ = 
"Parameter selection function." if module_name is not None: param_select_func_nb.__module__ = module_name jit_kwargs = merge_dicts(dict(nogil=True), jit_kwargs) param_select_func_nb = njit(param_select_func_nb, **jit_kwargs) setattr(Indicator, "param_select_func_nb", param_select_func_nb) def custom_func( input_tuple: tp.Tuple[tp.AnyArray, ...], in_output_tuple: tp.Tuple[tp.List[tp.AnyArray], ...], param_tuple: tp.Tuple[tp.List[tp.ParamValue], ...], *_args, input_shape: tp.Optional[tp.Shape] = None, per_column: bool = False, split_columns: bool = False, skipna: bool = False, return_cache: bool = False, use_cache: tp.Union[bool, CacheOutputT] = True, jitted_loop: bool = False, jitted_warmup: bool = False, param_index: tp.Optional[tp.Index] = None, final_index: tp.Optional[tp.Index] = None, single_comb: bool = False, execute_kwargs: tp.KwargsLike = None, **_kwargs, ) -> tp.Union[None, CacheOutputT, tp.Array2d, tp.List[tp.Array2d]]: """Custom function that forwards inputs and parameters to `apply_func`.""" if jitted_loop and not checks.is_numba_func(apply_func): raise ValueError("Apply function must be Numba-compiled for jitted_loop=True") if skipna and len(in_output_tuple) > 1: raise ValueError("NaNs cannot be skipped for in-outputs") if skipna and jitted_loop: raise ValueError("NaNs cannot be skipped when jitted_loop=True") if forward_skipna: _kwargs["skipna"] = skipna skipna = False if execute_kwargs is None: execute_kwargs = {} else: execute_kwargs = dict(execute_kwargs) _cache_pass_packed = cache_pass_packed _cache_pass_per_column = cache_pass_per_column # Prepend positional arguments args_before = () if input_shape is not None and "input_shape" not in kwargs_as_args: if per_column or takes_1d: args_before += ((input_shape[0],),) elif split_columns and len(input_shape) == 2: args_before += ((input_shape[0], 1),) else: args_before += (input_shape,) # Append positional arguments more_args = () for k in kwargs_as_args: if k == "per_column": value = per_column elif k 
== "takes_1d": value = per_column elif k == "split_columns": value = per_column else: value = _kwargs.pop(k) # important: remove from kwargs more_args += (value,) # Resolve the number of parameters if len(input_tuple) > 0: if input_tuple[0].ndim == 1: n_cols = 1 else: n_cols = input_tuple[0].shape[1] elif input_shape is not None: if len(input_shape) == 1: n_cols = 1 else: n_cols = input_shape[1] else: n_cols = None if per_column: n_params = n_cols else: n_params = len(param_tuple[0]) if len(param_tuple) > 0 else 1 # Caching cache = use_cache if isinstance(cache, bool): if cache and cache_func is not None: _input_tuple = input_tuple _in_output_tuple = () for in_outputs in in_output_tuple: if checks.is_numba_func(cache_func): _in_outputs = to_typed_list(in_outputs) else: _in_outputs = in_outputs _in_output_tuple += (_in_outputs,) _param_tuple = () for params in param_tuple: if checks.is_numba_func(cache_func): _params = to_typed_list(params) else: _params = params _param_tuple += (_params,) if _cache_pass_packed is None: _cache_pass_packed = pass_packed if _cache_pass_per_column is None and per_column: _cache_pass_per_column = True if _cache_pass_per_column is None: _cache_pass_per_column = pass_per_column cache_more_args = tuple(more_args) cache_kwargs = dict(_kwargs) if _cache_pass_per_column: if "per_column" not in kwargs_as_args: if jitted_loop: cache_more_args += (per_column,) else: cache_kwargs["per_column"] = per_column if _cache_pass_packed: cache = cache_func( *args_before, _input_tuple, _in_output_tuple, _param_tuple, *_args, *cache_more_args, **cache_kwargs, ) else: cache = cache_func( *args_before, *_input_tuple, *_in_output_tuple, *_param_tuple, *_args, *cache_more_args, **cache_kwargs, ) else: cache = None if return_cache: return cache if cache is None: cache = () if not isinstance(cache, tuple): cache = (cache,) # Prepare inputs def _expand_input(input: tp.MaybeList[tp.AnyArray], multiple: bool = False) -> tp.List[tp.AnyArray]: if jitted_loop: _inputs 
= List() else: _inputs = [] if per_column: if multiple: _input = input[0] else: _input = input if _input.ndim == 2: for i in range(_input.shape[1]): if takes_1d: if isinstance(_input, pd.DataFrame): _inputs.append(_input.iloc[:, i]) else: _inputs.append(_input[:, i]) else: if isinstance(_input, pd.DataFrame): _inputs.append(_input.iloc[:, i : i + 1]) else: _inputs.append(_input[:, i : i + 1]) else: _inputs.append(_input) else: for p in range(n_params): if multiple: _input = input[p] else: _input = input if takes_1d or split_columns: if isinstance(_input, pd.DataFrame): for i in range(_input.shape[1]): if takes_1d: _inputs.append(_input.iloc[:, i]) else: _inputs.append(_input.iloc[:, i : i + 1]) elif _input.ndim == 2: for i in range(_input.shape[1]): if takes_1d: _inputs.append(_input[:, i]) else: _inputs.append(_input[:, i : i + 1]) else: _inputs.append(_input) else: _inputs.append(_input) return _inputs _input_tuple = () for input in input_tuple: _inputs = _expand_input(input) _input_tuple += (_inputs,) if skipna and len(_input_tuple) > 0: new_input_tuple = tuple([[] for _ in range(len(_input_tuple))]) nan_masks = [] any_nan = False for i in range(len(_input_tuple[0])): inputs = [] for k in range(len(_input_tuple)): input = _input_tuple[k][i] if input.ndim == 2 and input.shape[1] > 1: raise ValueError( "NaNs cannot be skipped for multi-columnar inputs. Use split_columns=True." 
) inputs.append(input) nan_mask = build_nan_mask(*inputs) nan_masks.append(nan_mask) if not any_nan and nan_mask is not None and np.any(nan_mask): any_nan = True if any_nan: inputs = squeeze_nan(*inputs, nan_mask=nan_mask) for k in range(len(_input_tuple)): new_input_tuple[k].append(inputs[k]) _input_tuple = new_input_tuple if any_nan: def post_execute_func(outputs, nan_masks): new_outputs = [] for i, output in enumerate(outputs): if isinstance(output, (tuple, list, List)): output = unsqueeze_nan(*output, nan_mask=nan_masks[i]) else: output = unsqueeze_nan(output, nan_mask=nan_masks[i]) new_outputs.append(output) return new_outputs if "post_execute_func" in execute_kwargs: raise ValueError("Cannot use custom post_execute_func when skipna=True") execute_kwargs["post_execute_func"] = post_execute_func if "post_execute_kwargs" in execute_kwargs: raise ValueError("Cannot use custom post_execute_kwargs when skipna=True") execute_kwargs["post_execute_kwargs"] = dict(outputs=Rep("results"), nan_masks=nan_masks) if "post_execute_on_sorted" in execute_kwargs: raise ValueError("Cannot use custom post_execute_on_sorted when skipna=True") execute_kwargs["post_execute_on_sorted"] = True _in_output_tuple = () for in_outputs in in_output_tuple: _in_outputs = _expand_input(in_outputs, multiple=True) _in_output_tuple += (_in_outputs,) _param_tuple = () for params in param_tuple: if not per_column and (takes_1d or split_columns): _params = [params[p] for p in range(len(params)) for _ in range(n_cols)] else: _params = params if jitted_loop: if len(_params) > 0 and np.isscalar(_params[0]): _params = np.asarray(_params) else: _params = to_typed_list(_params) _param_tuple += (_params,) if not per_column and (takes_1d or split_columns): _n_params = n_params * n_cols keys = final_index else: _n_params = n_params keys = param_index execute_kwargs = merge_dicts(dict(show_progress=False if single_comb else None), execute_kwargs) execute_kwargs["keys"] = keys if pass_per_column: if 
"per_column" not in kwargs_as_args: if jitted_loop: more_args += (per_column,) else: _kwargs["per_column"] = per_column # Apply function and concatenate outputs if jitted_loop: return combining.apply_and_concat( _n_params, param_select_func_nb, args_before, *_input_tuple, *_in_output_tuple, *_param_tuple, *_args, *more_args, *cache, **_kwargs, n_outputs=num_ret_outputs, jitted_loop=True, jitted_warmup=jitted_warmup, execute_kwargs=execute_kwargs, ) tasks = [] for i in range(_n_params): if select_params: _inputs = tuple(_inputs[i] for _inputs in _input_tuple) _in_outputs = tuple(_in_outputs[i] for _in_outputs in _in_output_tuple) _params = tuple(_params[i] for _params in _param_tuple) else: _inputs = _input_tuple _in_outputs = _in_output_tuple _params = _param_tuple tasks.append( Task( apply_func, *((i,) if not select_params else ()), *args_before, *((_inputs,) if pass_packed else _inputs), *((_in_outputs,) if pass_packed else _in_outputs), *((_params,) if pass_packed else _params), *_args, *more_args, *cache, **_kwargs, ) ) return combining.apply_and_concat_each( tasks, n_outputs=num_ret_outputs, execute_kwargs=execute_kwargs, ) return self.with_custom_func( custom_func, pass_packed=True, pass_param_index=True, pass_final_index=True, pass_single_comb=True, **kwargs, ) # ############# Exploration ############# # _custom_indicators: tp.ClassVar[Config] = HybridConfig() @class_property def custom_indicators(cls) -> Config: """Custom indicators keyed by custom locations.""" return cls._custom_indicators @classmethod def list_custom_locations(cls) -> tp.List[str]: """List custom locations. Appear in the order they were registered.""" return list(cls.custom_indicators.keys()) @classmethod def list_builtin_locations(cls) -> tp.List[str]: """List built-in locations. 
@classmethod
def list_locations(cls) -> tp.List[str]:
    """List all supported locations.

    Custom locations come first, followed by the built-in ones."""
    all_locations = list(cls.list_custom_locations())
    all_locations.extend(cls.list_builtin_locations())
    return all_locations


@classmethod
def match_location(cls, location: str) -> tp.Optional[str]:
    """Match location.

    Compares `location` case-insensitively against every supported location and returns
    the first supported location that is equal, or None if there is none."""
    for candidate in cls.list_locations():
        if candidate.lower() != location.lower():
            continue
        return candidate
    return None


@classmethod
def split_indicator_name(cls, name: str) -> tp.Tuple[tp.Optional[str], tp.Optional[str]]:
    """Split an indicator name into location and actual name.

    Resolution order:

    * `name` is itself a supported location: returns the matched location and None
    * `name` contains a colon: the first two colon-separated parts (stripped) become
        the location and the name respectively
    * `name` starts (case-insensitively) with a supported location followed by an underscore:
        returns that location and the remainder of the name
    * Otherwise: returns None and the unchanged name
    """
    whole_match = cls.match_location(name)
    if whole_match is not None:
        return whole_match, None

    if ":" in name:
        parts = name.split(":")
        return parts[0].strip(), parts[1].strip()

    if "_" in name:
        lowered_name = name.lower()
        for candidate in cls.list_locations():
            if lowered_name.startswith(candidate.lower() + "_"):
                # Cut off the location prefix together with the underscore
                return candidate, name[len(candidate) + 1 :]
    return None, name
@classmethod
def register_custom_indicator(
    cls,
    indicator: tp.Union[str, tp.Type[IndicatorBase]],
    name: tp.Optional[str] = None,
    location: tp.Optional[str] = None,
    if_exists: str = "raise",
) -> None:
    """Register a custom indicator under a custom location.

    Args:
        indicator (str or type): Indicator class, or a string that is resolved
            with `IndicatorFactory.get_indicator`.
        name (str): Name to register the indicator under.

            Defaults to the name of the indicator class. If `location` is None, the name is
            split with `IndicatorFactory.split_indicator_name` to extract the location.
        location (str): Custom location. Defaults to "custom".

            Matched case-insensitively against the supported locations.
        if_exists (str): What to do if an indicator with the same (case-insensitive) name
            already exists under the location: "raise", "skip", or "override".

    Raises:
        ValueError: If the name or location is not a valid identifier, the name is missing,
            the location shadows a built-in location, the indicator already exists and
            `if_exists` is "raise", or `if_exists` is invalid.
    """
    if isinstance(indicator, str):
        indicator = cls.get_indicator(indicator)
    if name is None:
        name = indicator.__name__
    elif location is None:
        location, name = cls.split_indicator_name(name)
        if name is None:
            # The provided name was nothing but a location
            raise ValueError(f"Custom name is missing: got only location '{location}'")
    if location is None:
        location = "custom"
    else:
        matched_location = cls.match_location(location)
        if matched_location is not None:
            location = matched_location
    if not name.isidentifier():
        raise ValueError(f"Custom name '{name}' must be a valid variable name")
    if not location.isidentifier():
        raise ValueError(f"Custom location '{location}' must be a valid variable name")
    if location in cls.list_builtin_locations():
        raise ValueError(f"Custom location '{location}' shadows a built-in location with the same name")
    if location not in cls.custom_indicators:
        cls.custom_indicators[location] = dict()
    # Iterate over a snapshot since differently-cased duplicates are removed on override
    for k in list(cls.custom_indicators[location].keys()):
        if name.upper() != k.upper():
            continue
        mode = if_exists.lower()
        if mode == "raise":
            raise ValueError(f"Indicator with name '{name}' already exists under location '{location}'")
        if mode == "skip":
            return None
        if mode != "override":
            raise ValueError(f"Invalid if_exists: '{if_exists}'")
        if k != name:
            # Lookups are case-insensitive and return the first hit: an entry that differs
            # only in case would otherwise survive and keep shadowing the new indicator
            del cls.custom_indicators[location][k]
    cls.custom_indicators[location][name] = indicator
If `location` is None, deregisters all indicators with the same name across all custom locations.""" if location is not None: matched_location = cls.match_location(location) if matched_location is not None: location = matched_location if name is None: if location is None: for k in list(cls.custom_indicators.keys()): del cls.custom_indicators[k] else: del cls.custom_indicators[location] else: if location is None: location, name = cls.split_indicator_name(name) if location is None: for k, v in list(cls.custom_indicators.items()): for k2 in list(cls.custom_indicators[k].keys()): if name.upper() == k2.upper(): del cls.custom_indicators[k][k2] if remove_location and len(cls.custom_indicators[k]) == 0: del cls.custom_indicators[k] else: for k in list(cls.custom_indicators[location].keys()): if name.upper() == k.upper(): del cls.custom_indicators[location][k] if remove_location and len(cls.custom_indicators[location]) == 0: del cls.custom_indicators[location] @classmethod def get_custom_indicator( cls, name: str, location: tp.Optional[str] = None, return_first: bool = False, ) -> tp.Type[IndicatorBase]: """Get a custom indicator.""" if location is None: location, name = cls.split_indicator_name(name) else: matched_location = cls.match_location(location) if matched_location is not None: location = matched_location name = name.upper() if location is None: found_indicators = [] for k, v in cls.custom_indicators.items(): for k2, v2 in v.items(): k2 = k2.upper() if k2 == name: found_indicators.append(v2) if len(found_indicators) == 1: return found_indicators[0] if len(found_indicators) > 1: if return_first: return found_indicators[0] raise KeyError(f"Found multiple custom indicators with name '{name}'") raise KeyError(f"Found no custom indicator with name '{name}'") else: for k, v in cls.custom_indicators[location].items(): k = k.upper() if k == name: return v raise KeyError(f"Found no custom indicator with name '{name}' under location '{location}'") @classmethod def 
@classmethod
def list_indicators(
    cls,
    pattern: tp.Optional[str] = None,
    case_sensitive: bool = False,
    use_regex: bool = False,
    location: tp.Optional[str] = None,
    prepend_location: tp.Optional[bool] = None,
) -> tp.List[str]:
    """List indicators, optionally matching a pattern.

    Pattern can also be a location, in such a case all indicators from that location will be returned.
    For supported locations, see `IndicatorFactory.list_locations`.

    The pattern is always matched against the indicator name without its location prefix,
    starting at the beginning of the name (`re.match`).

    Args:
        pattern (str): Glob-style pattern, or a regular expression if `use_regex` is True.
        case_sensitive (bool): Whether the pattern is matched case-sensitively.
        use_regex (bool): Whether `pattern` is a regular expression rather than a glob.
        location (str): Location to list the indicators from. Matched case-insensitively.
        prepend_location (bool): Whether to return names as "location:name".

            Defaults to True if no location is given, otherwise to False.

    Returns:
        List of indicator names.
    """
    if pattern is not None and location is None and cls.match_location(pattern) is not None:
        # The pattern names a location rather than an indicator
        location = pattern
        pattern = None
    if prepend_location is None:
        prepend_location = location is None

    def _with_location(prefix, names):
        # Prefix each name exactly once, and only if requested
        if prepend_location:
            return [prefix + ":" + x for x in names]
        return list(names)

    with WarningsFiltered():
        if location is not None:
            matched_location = cls.match_location(location)
            if matched_location is not None:
                location = matched_location
            if location in cls.list_custom_locations():
                all_indicators = cls.list_custom_indicators(location=location, prepend_location=prepend_location)
            else:
                all_indicators = _with_location(location, getattr(cls, f"list_{location}_indicators")())
        else:
            from vectorbtpro.utils.module_ import check_installed

            all_indicators = [
                *cls.list_custom_indicators(prepend_location=prepend_location),
                *_with_location("vbt", cls.list_vbt_indicators()),
                *_with_location("talib", cls.list_talib_indicators() if check_installed("talib") else []),
                *_with_location(
                    "pandas_ta",
                    cls.list_pandas_ta_indicators() if check_installed("pandas_ta") else [],
                ),
                *_with_location("ta", cls.list_ta_indicators() if check_installed("ta") else []),
                *_with_location(
                    "technical",
                    cls.list_technical_indicators() if check_installed("technical") else [],
                ),
                *_with_location(
                    "techcon",
                    cls.list_techcon_indicators() if check_installed("technical") else [],
                ),
                *_with_location(
                    "smc",
                    cls.list_smc_indicators() if check_installed("smartmoneyconcepts") else [],
                ),
                *_with_location("wqa101", map(str, range(1, 102))),
            ]

    if pattern is None:
        return list(all_indicators)

    regex = pattern if use_regex else fnmatch.translate(pattern)
    # Use a flag rather than lowercasing the pattern: lowercasing would change
    # the meaning of regex escapes such as \D or \W
    flags = 0 if case_sensitive else re.IGNORECASE
    found_indicators = []
    for indicator in all_indicators:
        # Strip the location prefix (if any) so that only the name gets matched
        indicator_name = indicator.partition(":")[2] if prepend_location else indicator
        if re.match(regex, indicator_name, flags):
            found_indicators.append(indicator)
    return found_indicators
@classmethod
def list_talib_indicators(cls) -> tp.List[str]:
    """List all parseable indicators in `talib`.

    Returns the function names reported by `talib.get_functions` in sorted order.
    Requires `talib` to be importable."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("talib")
    import talib

    func_names = list(talib.get_functions())
    func_names.sort()
    return func_names
Returns: Indicator Usage: ```pycon >>> SMA = vbt.IF.from_talib('SMA') >>> sma = SMA.run(price, timeperiod=[2, 3]) >>> sma.real sma_timeperiod 2 3 a b a b 2020-01-01 NaN NaN NaN NaN 2020-01-02 1.5 4.5 NaN NaN 2020-01-03 2.5 3.5 2.0 4.0 2020-01-04 3.5 2.5 3.0 3.0 2020-01-05 4.5 1.5 4.0 2.0 ``` * To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`: ```pycon >>> vbt.phelp(SMA.run) SMA.run( close, timeperiod=Default(value=30), timeframe=Default(value=None), short_name='sma', hide_params=None, hide_default=True, **kwargs ): Run `SMA` indicator. * Inputs: `close` * Parameters: `timeperiod`, `timeframe` * Outputs: `real` Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all. Set `hide_default` to False to show the column levels of the parameters with a default value. Other keyword arguments are passed to `SMA.run_pipeline`. ``` * To plot an indicator: ```pycon >>> sma.plot(column=(2, 'a')).show() ``` ![](/assets/images/api/talib_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/talib_plot.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("talib") import talib from talib import abstract from vectorbtpro.indicators.talib_ import talib_func, talib_plot_func func_name = func_name.upper() info = abstract.Function(func_name).info input_names = [] for in_names in info["input_names"].values(): if isinstance(in_names, (list, tuple)): input_names.extend(list(in_names)) else: input_names.append(in_names) class_name = info["name"] class_docstring = "{}, {}".format(info["display_name"], info["group"]) param_names = list(info["parameters"].keys()) + ["timeframe"] output_names = info["output_names"] output_flags = info["output_flags"] _talib_func = talib_func(func_name) _talib_plot_func = talib_plot_func(func_name) def apply_func( input_tuple: tp.Tuple[tp.Array2d, ...], in_output_tuple: tp.Tuple[tp.Array2d, ...], param_tuple: 
@classmethod
def parse_pandas_ta_config(
    cls,
    func: tp.Callable,
    test_input_names: tp.Optional[tp.Sequence[str]] = None,
    test_index_len: int = 100,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Kwargs:
    """Parse the config of a `pandas_ta` indicator.

    Input and parameter names are derived from the signature of `func`. Output names
    are derived by running `func` on random test data and standardizing the names
    of the returned Series/DataFrame columns.

    Args:
        func (callable): Indicator function.
        test_input_names (sequence of str): Argument names that are always considered inputs.
        test_index_len (int): Number of rows in the test data.
        silence_warnings (bool): Whether to silence warnings about outputs that cannot be parsed.
        **kwargs: Keyword arguments passed to `func` during the test run.

    Returns:
        Dict with the keys `class_name`, `class_docstring`, `input_names`, `param_names`,
        `output_names`, and `defaults`.

    Raises:
        ValueError: If no output with a matching index length is produced.
    """
    if test_input_names is None:
        test_input_names = {"open_", "open", "high", "low", "close", "adj_close", "volume", "dividends", "split"}

    input_names = []
    param_names = []
    output_names = []
    defaults = {}

    # Parse the function signature of the indicator to get input names
    sig = inspect.signature(func)
    for k, v in sig.parameters.items():
        if v.kind in (v.VAR_POSITIONAL, v.VAR_KEYWORD):
            continue
        if v.annotation is pd.Series:
            input_names.append(k)
        elif k in test_input_names:
            input_names.append(k)
        elif v.default is inspect.Parameter.empty:
            # Any positional argument is considered input
            # (identity check: "==" is ambiguous for array-like defaults)
            input_names.append(k)
        else:
            param_names.append(k)
            defaults[k] = v.default

    # To get output names, we need to run the indicator
    test_df = pd.DataFrame(
        {c: np.random.uniform(1, 10, size=(test_index_len,)) for c in input_names},
        index=pd.date_range("2020", periods=test_index_len),
    )
    new_args = merge_dicts({c: test_df[c] for c in input_names}, kwargs)
    result = suppress_stdout(func)(**new_args)

    # Concatenate Series/DataFrames if the result is a tuple
    if isinstance(result, tuple):
        results = []
        for i, r in enumerate(result):
            if len(r.index) != len(test_df.index):
                if not silence_warnings:
                    warn(f"Couldn't parse the output at index {i}: mismatching index")
            else:
                results.append(r)
        if len(results) > 1:
            result = pd.concat(results, axis=1)
        elif len(results) == 1:
            result = results[0]
        else:
            raise ValueError("Couldn't parse the output")

    # Test if the produced array has the same index length
    if len(result.index) != len(test_df.index):
        raise ValueError("Couldn't parse the output: mismatching index")

    # Standardize output names: remove numbers, remove hyphens, and bring to lower case
    output_cols = result.columns.tolist() if isinstance(result, pd.DataFrame) else [result.name]
    new_output_cols = []
    for output_col in output_cols:
        name_parts = []
        for name_part in output_col.split("_"):
            try:
                float(name_part)
            except ValueError:
                name_parts.append(name_part.replace("-", "_").lower())
        new_output_cols.append("_".join(name_parts))

    # Add numbers to duplicates while keeping the order of the output columns,
    # such that each name stays aligned with the position of its column
    n_total = Counter(new_output_cols)
    n_seen = Counter()
    for output_col in new_output_cols:
        if n_total[output_col] == 1:
            output_names.append(output_col)
        else:
            output_names.append(output_col + str(n_seen[output_col]))
            n_seen[output_col] += 1

    return dict(
        class_name=func.__name__.upper(),
        class_docstring=func.__doc__,
        input_names=input_names,
        param_names=param_names,
        output_names=output_names,
        defaults=defaults,
    )
Returns: Indicator Usage: ```pycon >>> SMA = vbt.IF.from_pandas_ta('SMA') >>> sma = SMA.run(price, length=[2, 3]) >>> sma.sma sma_length 2 3 a b a b 2020-01-01 NaN NaN NaN NaN 2020-01-02 1.5 4.5 NaN NaN 2020-01-03 2.5 3.5 2.0 4.0 2020-01-04 3.5 2.5 3.0 3.0 2020-01-05 4.5 1.5 4.0 2.0 ``` * To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`: ```pycon >>> vbt.phelp(SMA.run) SMA.run( close, length=Default(value=None), talib=Default(value=None), offset=Default(value=None), short_name='sma', hide_params=None, hide_default=True, **kwargs ): Run `SMA` indicator. * Inputs: `close` * Parameters: `length`, `talib`, `offset` * Outputs: `sma` Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all. Set `hide_default` to False to show the column levels of the parameters with a default value. Other keyword arguments are passed to `SMA.run_pipeline`. ``` * To get the indicator docstring, use the `help` command or print the `__doc__` attribute: ```pycon >>> print(SMA.__doc__) Simple Moving Average (SMA) The Simple Moving Average is the classic moving average that is the equally weighted average over n periods. Sources: https://www.tradingtechnologies.com/help/x-study/technical-indicator-definitions/simple-moving-average-sma/ Calculation: Default Inputs: length=10 SMA = SUM(close, length) / length Args: close (pd.Series): Series of 'close's length (int): It's period. Default: 10 offset (int): How many periods to offset the result. Default: 0 Kwargs: adjust (bool): Default: True presma (bool, optional): If True, uses SMA for initial value. fillna (value, optional): pd.DataFrame.fillna(value) fill_method (value, optional): Type of fill method Returns: pd.Series: New feature generated. 
``` """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("pandas_ta") import pandas_ta func_name = func_name.lower() func = getattr(pandas_ta, func_name) if parse_kwargs is None: parse_kwargs = {} config = cls.parse_pandas_ta_config(func, **parse_kwargs) def apply_func( input_tuple: tp.Tuple[tp.AnyArray, ...], in_output_tuple: tp.Tuple[tp.SeriesFrame, ...], param_tuple: tp.Tuple[tp.ParamValue, ...], **_kwargs, ) -> tp.MaybeTuple[tp.Array2d]: is_series = isinstance(input_tuple[0], pd.Series) n_input_cols = 1 if is_series else len(input_tuple[0].columns) outputs = [] for col in range(n_input_cols): output = suppress_stdout(func)( **{ name: input_tuple[i] if is_series else input_tuple[i].iloc[:, col] for i, name in enumerate(config["input_names"]) }, **{name: param_tuple[i] for i, name in enumerate(config["param_names"])}, **_kwargs, ) if isinstance(output, tuple): _outputs = [] for o in output: if len(input_tuple[0].index) == len(o.index): _outputs.append(o) if len(_outputs) > 1: output = pd.concat(_outputs, axis=1) elif len(_outputs) == 1: output = _outputs[0] else: raise ValueError("No valid outputs were returned") if isinstance(output, pd.DataFrame): output = tuple([output.iloc[:, i] for i in range(len(output.columns))]) outputs.append(output) if isinstance(outputs[0], tuple): # multiple outputs outputs = list(zip(*outputs)) return tuple(map(column_stack_arrays, outputs)) return column_stack_arrays(outputs) kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs) Indicator = cls( **merge_dicts(dict(module_name=__name__ + ".pandas_ta"), config, factory_kwargs), ).with_apply_func(apply_func, pass_packed=True, keep_pd=True, to_2d=False, **kwargs) return Indicator @classmethod def list_ta_indicators(cls, uppercase: bool = False) -> tp.List[str]: """List all parseable indicators in `ta`.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("ta") import ta ta_module_names = [k for k in 
@classmethod
def parse_ta_config(cls, ind_cls: IndicatorMixinT) -> tp.Kwargs:
    """Parse the config of a `ta` indicator.

    Arguments of the class constructor annotated with `pd.Series` become inputs, all other
    arguments become parameters. Public methods that either have `pd.Series` as the return
    annotation or mention a `pandas.Series` return value in their docstring become outputs.

    Args:
        ind_cls: Indicator class from `ta`.

    Returns:
        Dict with the keys `class_name`, `class_docstring`, `input_names`, `param_names`,
        `output_names`, and `defaults`.

    Raises:
        ValueError: If a constructor argument has no annotation.
    """
    import re

    input_names = []
    param_names = []
    defaults = {}
    output_names = []

    # Parse the __init__ signature of the indicator class to get input names
    sig = inspect.signature(ind_cls)
    for k, v in sig.parameters.items():
        if v.kind in (v.VAR_POSITIONAL, v.VAR_KEYWORD):
            continue
        if v.annotation is inspect.Parameter.empty:
            raise ValueError(f'Argument "{k}" has no annotation')
        if v.annotation is pd.Series:
            input_names.append(k)
        else:
            param_names.append(k)
            # Identity check: "!=" is ambiguous for array-like defaults
            if v.default is not inspect.Parameter.empty:
                defaults[k] = v.default

    # Get output names by looking into instance methods
    for attr in dir(ind_cls):
        if attr.startswith("_"):
            continue
        member = getattr(ind_cls, attr)
        if not callable(member):
            # Plain class attributes have no signature and cannot be outputs
            continue
        try:
            return_annotation = inspect.signature(member).return_annotation
        except (TypeError, ValueError):
            # No introspectable signature: fall back to the docstring
            return_annotation = inspect.Signature.empty
        if return_annotation is pd.Series:
            output_names.append(attr)
        elif re.search(r"Returns:\s*pandas\.Series", member.__doc__ or ""):
            # Docstring may be missing, and its indentation may vary
            output_names.append(attr)

    return dict(
        class_name=ind_cls.__name__,
        class_docstring=ind_cls.__doc__,
        input_names=input_names,
        param_names=param_names,
        output_names=output_names,
        defaults=defaults,
    )
``` """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("ta") ind_cls = cls.find_ta_indicator(cls_name) config = cls.parse_ta_config(ind_cls) def apply_func( input_tuple: tp.Tuple[tp.AnyArray, ...], in_output_tuple: tp.Tuple[tp.SeriesFrame, ...], param_tuple: tp.Tuple[tp.ParamValue, ...], **_kwargs, ) -> tp.MaybeTuple[tp.Array2d]: is_series = isinstance(input_tuple[0], pd.Series) n_input_cols = 1 if is_series else len(input_tuple[0].columns) outputs = [] for col in range(n_input_cols): ind = ind_cls( **{ name: input_tuple[i] if is_series else input_tuple[i].iloc[:, col] for i, name in enumerate(config["input_names"]) }, **{name: param_tuple[i] for i, name in enumerate(config["param_names"])}, **_kwargs, ) output = [] for output_name in config["output_names"]: output.append(getattr(ind, output_name)()) if len(output) == 1: output = output[0] else: output = tuple(output) outputs.append(output) if isinstance(outputs[0], tuple): # multiple outputs outputs = list(zip(*outputs)) return tuple(map(column_stack_arrays, outputs)) return column_stack_arrays(outputs) kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs) Indicator = cls(**merge_dicts(dict(module_name=__name__ + ".ta"), config, factory_kwargs)).with_apply_func( apply_func, pass_packed=True, keep_pd=True, to_2d=False, **kwargs, ) return Indicator @classmethod def parse_technical_config(cls, func: tp.Callable, test_index_len: int = 100) -> tp.Kwargs: """Parse the config of a `technical` indicator.""" df = pd.DataFrame( np.random.randint(1, 10, size=(test_index_len, 5)), index=pd.date_range("2020", periods=test_index_len), columns=["open", "high", "low", "close", "volume"], ) func_arg_names = get_func_arg_names(func) func_kwargs = get_func_kwargs(func) args = () input_names = [] param_names = [] output_names = [] defaults = {} for arg_name in func_arg_names: if arg_name == "field": continue if arg_name in ("dataframe", "df", "bars"): args += (df,) if 
"field" in func_kwargs: input_names.append(func_kwargs["field"]) else: input_names.extend(["open", "high", "low", "close", "volume"]) elif arg_name in ("series", "sr"): args += (df["close"],) input_names.append("close") elif arg_name in ("open", "high", "low", "close", "volume"): args += (df["close"],) input_names.append(arg_name) else: if arg_name not in func_kwargs: args += (5,) else: defaults[arg_name] = func_kwargs[arg_name] param_names.append(arg_name) if len(input_names) == 0: raise ValueError("Couldn't parse the output: unknown input arguments") def _validate_series(sr, name: tp.Optional[str] = None): if not isinstance(sr, pd.Series): raise TypeError("Couldn't parse the output: wrong output type") if len(sr.index) != len(df.index): raise ValueError("Couldn't parse the output: mismatching index") if np.issubdtype(sr.dtype, object): raise ValueError("Couldn't parse the output: wrong output data type") if name is None and sr.name is None: raise ValueError("Couldn't parse the output: missing output name") out = suppress_stdout(func)(*args) if isinstance(out, list): out = np.asarray(out) if isinstance(out, np.ndarray): out = pd.Series(out) if isinstance(out, dict): out = pd.DataFrame(out) if isinstance(out, tuple): out = pd.concat(out, axis=1) if isinstance(out, (pd.Series, pd.DataFrame)): if isinstance(out, pd.DataFrame): for c in out.columns: _validate_series(out[c], name=c) output_names.append(c) else: if out.name is not None: out_name = out.name else: out_name = func.__name__.lower() _validate_series(out, name=out_name) output_names.append(out_name) else: raise TypeError("Couldn't parse the output: wrong output type") new_output_names = [] for name in output_names: name = name.replace(" ", "").lower() if len(output_names) == 1 and name == "close": new_output_names.append(func.__name__.lower()) continue if name in ("open", "high", "low", "close", "volume", "data"): continue new_output_names.append(name) return dict( class_name=func.__name__.upper(), 
@classmethod
def from_technical(
    cls,
    func_name: str,
    parse_kwargs: tp.KwargsLike = None,
    factory_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.Type[IndicatorBase]:
    """Build an indicator class around a `technical` function.

    Requires [technical](https://github.com/freqtrade/technical) installed.

    Args:
        func_name (str): Function name.
        parse_kwargs (dict): Keyword arguments passed to `IndicatorFactory.parse_technical_config`.
        factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
        **kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.

    Returns:
        Indicator

    Usage:
        ```pycon
        >>> ROLLING_MEAN = vbt.IF.from_technical("ROLLING_MEAN")

        >>> rolling_mean = ROLLING_MEAN.run(price, window=[3, 4])
        >>> rolling_mean.rolling_mean
        rolling_mean_window         3         4
                               a    b    a    b
        2020-01-01           NaN  NaN  NaN  NaN
        2020-01-02           NaN  NaN  NaN  NaN
        2020-01-03           2.0  4.0  NaN  NaN
        2020-01-04           3.0  3.0  2.5  3.5
        2020-01-05           4.0  2.0  3.5  2.5
        ```

        * To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:

        ```pycon
        >>> vbt.phelp(ROLLING_MEAN.run)
        ROLLING_MEAN.run(
            close,
            window=Default(value=200),
            min_periods=Default(value=None),
            short_name='rolling_mean',
            hide_params=None,
            hide_default=True,
            **kwargs
        ):
            Run `ROLLING_MEAN` indicator.

            * Inputs: `close`
            * Parameters: `window`, `min_periods`
            * Outputs: `rolling_mean`

            Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
            Set `hide_default` to False to show the column levels of the parameters with a default value.

            Other keyword arguments are passed to `ROLLING_MEAN.run_pipeline`.
        ```
    """
    func = cls.find_technical_indicator(func_name)
    func_arg_names = get_func_arg_names(func)
    if parse_kwargs is None:
        parse_kwargs = {}
    config = cls.parse_technical_config(func, **parse_kwargs)

    def apply_func(
        input_tuple: tp.Tuple[tp.Series, ...],
        in_output_tuple: tp.Tuple[tp.Series, ...],
        param_tuple: tp.Tuple[tp.ParamValue, ...],
        *_args,
        **_kwargs,
    ) -> tp.MaybeTuple[tp.Array1d]:
        input_series = {name: input_tuple[i] for i, name in enumerate(config["input_names"])}
        _kwargs = {**{name: param_tuple[i] for i, name in enumerate(config["param_names"])}, **_kwargs}
        # Build the leading positional (data) arguments in the order the function declares them
        func_args = ()
        for arg_name in func_arg_names:
            if arg_name in ("dataframe", "df", "bars"):
                func_args += (pd.DataFrame(input_series),)
            elif arg_name in ("series", "sr"):
                # A generic series argument is registered as the input "close"
                func_args += (input_series["close"],)
            elif arg_name in ("open", "high", "low", "close", "volume"):
                # Pass the series registered under this very name rather than always "close"
                func_args += (input_series[arg_name],)
            else:
                # First non-data argument: the rest is passed via parameters
                break

        out = suppress_stdout(func)(*func_args, *_args, **_kwargs)
        # Normalize the output to a Series or DataFrame
        if isinstance(out, list):
            out = np.asarray(out)
        if isinstance(out, np.ndarray):
            out = pd.Series(out)
        if isinstance(out, dict):
            out = pd.DataFrame(out)
        if isinstance(out, tuple):
            out = pd.concat(out, axis=1)
        if isinstance(out, pd.DataFrame):
            # Keep all columns if their number matches the registered outputs,
            # otherwise drop the columns that merely echo the input data
            keep_all = len(out.columns) == len(config["output_names"])
            outputs = []
            for c in out.columns:
                if keep_all:
                    outputs.append(out[c].values)
                elif c.replace(" ", "").lower() not in ("open", "high", "low", "close", "volume", "data"):
                    outputs.append(out[c].values)
            return tuple(outputs)
        return out.values

    kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs)
    Indicator = cls(
        **merge_dicts(dict(module_name=__name__ + ".technical"), config, factory_kwargs),
    ).with_apply_func(apply_func, pass_packed=True, keep_pd=True, takes_1d=True, **kwargs)
    return Indicator
`technical.consensus.consensus.Consensus`. Requires Technical library: https://github.com/freqtrade/technical""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("technical") from technical.consensus.consensus import Consensus checks.assert_subclass_of(consensus_cls, Consensus) def apply_func( open: tp.Series, high: tp.Series, low: tp.Series, close: tp.Series, volume: tp.Series, smooth: tp.Optional[int] = None, _consensus_cls: tp.Type[ConsensusT] = consensus_cls, ) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]: """Apply function for `technical.consensus.movingaverage.MovingAverageConsensus`.""" dataframe = pd.DataFrame( { "open": open, "high": high, "low": low, "close": close, "volume": volume, } ) consensus = _consensus_cls(dataframe) score = consensus.score(smooth=smooth) return ( score["buy"].values, score["sell"].values, score["buy_agreement"].values, score["sell_agreement"].values, score["buy_disagreement"].values, score["sell_disagreement"].values, ) if factory_kwargs is None: factory_kwargs = {} factory_kwargs = merge_dicts( dict( class_name="CON", module_name=__name__ + ".custom_techcon", short_name=None, input_names=["open", "high", "low", "close", "volume"], param_names=["smooth"], output_names=[ "buy", "sell", "buy_agreement", "sell_agreement", "buy_disagreement", "sell_disagreement", ], ), factory_kwargs, ) Indicator = cls(**factory_kwargs).with_apply_func( apply_func, takes_1d=True, keep_pd=True, smooth=None, **kwargs, ) def plot( self, column: tp.Optional[tp.Label] = None, buy_trace_kwargs: tp.KwargsLike = None, sell_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot `MA.ma` against `MA.close`. Args: column (str): Name of the column to plot. buy_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `buy`. 
sell_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `sell`. add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments passed to `fig.update_layout`. """ from vectorbtpro.utils.figure import make_figure from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column) if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) if buy_trace_kwargs is None: buy_trace_kwargs = {} if sell_trace_kwargs is None: sell_trace_kwargs = {} buy_trace_kwargs = merge_dicts( dict(name="Buy", line=dict(color=plotting_cfg["color_schema"]["green"])), buy_trace_kwargs, ) sell_trace_kwargs = merge_dicts( dict(name="Sell", line=dict(color=plotting_cfg["color_schema"]["red"])), sell_trace_kwargs, ) fig = self_col.buy.vbt.lineplot( trace_kwargs=buy_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) fig = self_col.sell.vbt.lineplot( trace_kwargs=sell_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) return fig Indicator.plot = plot return Indicator @classmethod def from_techcon(cls, cls_name: str, **kwargs) -> tp.Type[IndicatorBase]: """Create an indicator from a preset technical consensus. 
Supported are case-insensitive values `MACON` (or `MovingAverageConsensus`), `OSCCON` (or `OscillatorConsensus`), and `SUMCON` (or `SummaryConsensus`).""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("technical") if cls_name.lower() in ("MACON".lower(), "MovingAverageConsensus".lower()): from technical.consensus.movingaverage import MovingAverageConsensus return cls.from_custom_techcon( MovingAverageConsensus, factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="MACON"), **kwargs, ) if cls_name.lower() in ("OSCCON".lower(), "OscillatorConsensus".lower()): from technical.consensus.oscillator import OscillatorConsensus return cls.from_custom_techcon( OscillatorConsensus, factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="OSCCON"), **kwargs, ) if cls_name.lower() in ("SUMCON".lower(), "SummaryConsensus".lower()): from technical.consensus.summary import SummaryConsensus return cls.from_custom_techcon( SummaryConsensus, factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="SUMCON"), **kwargs, ) raise ValueError(f"Unknown technical consensus class '{cls_name}'") @classmethod def list_techcon_indicators(cls) -> tp.List[str]: """List all consensus indicators in `technical`.""" return sorted({"MACON", "OSCCON", "SUMCON"}) @classmethod def find_smc_indicator(cls, func_name: str, raise_error: bool = True) -> tp.Optional[tp.Callable]: """Get `smartmoneyconcepts` indicator class by its name.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("smartmoneyconcepts") from smartmoneyconcepts import smc for k in dir(smc): if not k.startswith("_"): if camel_to_snake_case(func_name) == camel_to_snake_case(k): return getattr(smc, k) if raise_error: raise AttributeError(f"Indicator '{func_name}' not found") return None @classmethod def parse_smc_config(cls, func: tp.Callable, collapse: bool = True, snake_case: bool = True) -> tp.Kwargs: """Parse the config of a `smartmoneyconcepts` 
indicator.""" func_arg_names = get_func_arg_names(func) input_names = [] param_names = [] defaults = {} dep_input_names = {} sig = inspect.signature(func) for k in func_arg_names: if k == "ohlc": input_names.extend(["open", "high", "low", "close", "volume"]) else: found_smc_indicator = cls.find_smc_indicator(k, raise_error=False) if found_smc_indicator is not None: dep_input_names[k] = [] k_func_config = cls.parse_smc_config(found_smc_indicator) if collapse: for input_name in k_func_config["input_names"]: if input_name not in input_names: input_names.append(input_name) for param_name in k_func_config["param_names"]: if param_name not in param_names: param_names.append(param_name) for k2, v2 in k_func_config["defaults"].items(): defaults[k2] = v2 else: for output_name in k_func_config["output_names"]: if output_name not in input_names: input_names.append(output_name) dep_input_names[k].append(output_name) else: v = sig.parameters[k] if v.kind not in (v.VAR_POSITIONAL, v.VAR_KEYWORD): if v.default == inspect.Parameter.empty and v.annotation == pd.DataFrame: if k not in input_names: input_names.append(k) else: if k not in param_names: param_names.append(k) if v.default != inspect.Parameter.empty: defaults[k] = v.default func_doc = inspect.getsource(func) output_names = re.findall(r'name="([^"]+)"', func_doc) output_names = [k.replace("%", "") for k in output_names] if snake_case: input_names = list(map(camel_to_snake_case, input_names)) param_names = list(map(camel_to_snake_case, param_names)) output_names = list(map(camel_to_snake_case, output_names)) return dict( class_name=func.__name__.upper(), class_docstring=func.__doc__, input_names=input_names, param_names=param_names, output_names=output_names, defaults=defaults, dep_input_names=dep_input_names, ) @classmethod def list_smc_indicators(cls, silence_warnings: bool = True, **kwargs) -> tp.List[str]: """List all parseable indicators in `smartmoneyconcepts`.""" from vectorbtpro.utils.module_ import 
assert_can_import assert_can_import("smartmoneyconcepts") from smartmoneyconcepts import smc indicators = set() for func_name in dir(smc): if not func_name.startswith("_"): try: cls.parse_smc_config(getattr(smc, func_name), **kwargs) indicators.add(func_name.upper()) except Exception as e: if not silence_warnings: warn(f"Function {func_name}: " + str(e)) return sorted(indicators) @classmethod def from_smc( cls, func_name: str, collapse: bool = True, parse_kwargs: tp.KwargsLike = None, factory_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.Type[IndicatorBase]: """Build an indicator class around a `smartmoneyconcepts` function. Requires [smart-money-concepts](https://github.com/joshyattridge/smart-money-concepts) installed. Args: func_name (str): Function name. collapse (bool): Whether to collapse all nested indicators into a single one. parse_kwargs (dict): Keyword arguments passed to `IndicatorFactory.parse_smc_config`. factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`. **kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`. 
""" func = cls.find_smc_indicator(func_name) func_arg_names = get_func_arg_names(func) if parse_kwargs is None: parse_kwargs = {} collapsed_config = cls.parse_smc_config(func, collapse=True, snake_case=True, **parse_kwargs) _ = collapsed_config.pop("dep_input_names") expanded_config = cls.parse_smc_config(func, collapse=False, snake_case=True, **parse_kwargs) dep_input_names = expanded_config.pop("dep_input_names") if collapse: config = collapsed_config else: config = expanded_config def apply_func( input_tuple: tp.Tuple[tp.Series, ...], in_output_tuple: tp.Tuple[tp.Series, ...], param_tuple: tp.Tuple[tp.ParamValue, ...], **_kwargs, ) -> tp.MaybeTuple[tp.Array1d]: named_args = dict(_kwargs) for i, input_name in enumerate(config["input_names"]): named_args[input_name] = input_tuple[i] for i, param_name in enumerate(config["param_names"]): named_args[param_name] = param_tuple[i] named_args["ohlc"] = pd.concat( [ named_args["open"].rename("open"), named_args["high"].rename("high"), named_args["low"].rename("low"), named_args["close"].rename("close"), named_args["volume"].rename("volume"), ], axis=1, ) if collapse and len(dep_input_names) > 0: for dep_func_name in dep_input_names: dep_func = cls.find_smc_indicator(dep_func_name) dep_func_arg_names = get_func_arg_names(dep_func) dep_output = dep_func(*[named_args[camel_to_snake_case(k)] for k in dep_func_arg_names]) dep_output.index = input_tuple[0].index named_args[dep_func_name] = dep_output elif not collapse: for dep_func_name in dep_input_names: dep_func = cls.find_smc_indicator(dep_func_name) dep_config = cls.parse_smc_config(dep_func, collapse=False, snake_case=False, **parse_kwargs) named_args[dep_func_name] = pd.concat( [named_args[input_name] for input_name in dep_input_names[dep_func_name]], axis=1, keys=dep_config["output_names"], ) output = func(*[named_args[camel_to_snake_case(k)] for k in func_arg_names]) return tuple([output[c] for c in output.columns]) kwargs = merge_dicts({k: Default(v) for k, v in 
config.pop("defaults").items()}, kwargs) Indicator = cls( **merge_dicts(dict(module_name=__name__ + ".smc"), config, factory_kwargs), ).with_apply_func(apply_func, pass_packed=True, keep_pd=True, takes_1d=True, **kwargs) return Indicator # ############# Expressions ############# # @hybrid_method def from_expr( cls_or_self, expr: str, parse_annotations: bool = True, factory_kwargs: tp.KwargsLike = None, magnet_inputs: tp.Iterable[str] = None, magnet_in_outputs: tp.Iterable[str] = None, magnet_params: tp.Iterable[str] = None, func_mapping: tp.KwargsLike = None, res_func_mapping: tp.KwargsLike = None, use_pd_eval: tp.Optional[bool] = None, pd_eval_kwargs: tp.KwargsLike = None, return_clean_expr: bool = False, **kwargs, ) -> tp.Union[str, tp.Type[IndicatorBase]]: """Build an indicator class from an indicator expression. Args: expr (str): Expression. Expression must be a string with a valid Python code. Supported are both single-line and multi-line expressions. parse_annotations (bool): Whether to parse annotations starting with `@`. factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`. Only applied when calling the class method. magnet_inputs (iterable of str): Names recognized as input names. Defaults to `open`, `high`, `low`, `close`, and `volume`. magnet_in_outputs (iterable of str): Names recognized as in-output names. Defaults to an empty list. magnet_params (iterable of str): Names recognized as params names. Defaults to an empty list. func_mapping (mapping): Mapping merged over `vectorbtpro.indicators.expr.expr_func_config`. Each key must be a function name and each value must be a dict with `func` and optionally `magnet_inputs`, `magnet_in_outputs`, and `magnet_params`. res_func_mapping (mapping): Mapping merged over `vectorbtpro.indicators.expr.expr_res_func_config`. Each key must be a function name and each value must be a dict with `func` and optionally `magnet_inputs`, `magnet_in_outputs`, and `magnet_params`. 
use_pd_eval (bool): Whether to use `pd.eval`. Defaults to False. Otherwise, uses `vectorbtpro.utils.eval_.evaluate`. !!! hint By default, operates on NumPy objects using NumExpr. If you want to operate on Pandas objects, set `keep_pd` to True. pd_eval_kwargs (dict): Keyword arguments passed to `pd.eval`. return_clean_expr (bool): Whether to return a cleaned expression. **kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`. Returns: Indicator Searches each variable name parsed from `expr` in * `vectorbtpro.indicators.expr.expr_res_func_config` (calls right away) * `vectorbtpro.indicators.expr.expr_func_config` * inputs, in-outputs, and params * keyword arguments * attributes of `np` * attributes of `vectorbtpro.generic.nb` (with and without `_nb` suffix) * attributes of `vbt` `vectorbtpro.indicators.expr.expr_func_config` and `vectorbtpro.indicators.expr.expr_res_func_config` can be overridden with `func_mapping` and `res_func_mapping` respectively. !!! note Each variable name is case-sensitive. When using the class method, all names are parsed from the expression itself. If any of `open`, `high`, `low`, `close`, and `volume` appear in the expression or in `magnet_inputs` in either `vectorbtpro.indicators.expr.expr_func_config` or `vectorbtpro.indicators.expr.expr_res_func_config`, they are automatically added to `input_names`. Set `magnet_inputs` to an empty list to disable this logic. If the expression begins with a valid variable name and a colon (`:`), the variable name will be used as the name of the generated class. Provide another variable in the square brackets after this one and before the colon to specify the indicator's short name. 
If `parse_annotations` is True, variables that start with `@` have a special meaning: * `@in_*`: input variable * `@inout_*`: in-output variable * `@p_*`: parameter variable * `@out_*`: output variable * `@out_*:`: indicates that the next part until a comma is an output * `@talib_*`: name of a TA-Lib function. Uses the indicator's `apply_func`. * `@res_*`: name of the indicator to resolve automatically. Input names can overlap with those of other indicators, while all other information gets a prefix with the indicator's short name. * `@settings(*)`: settings to be merged with the current `IndicatorFactory.from_expr` settings. Everything within the parentheses gets evaluated using the Pythons `eval` command and must be a dictionary. Overrides defaults but gets overridden by any argument passed to this method. Arguments `expr` and `parse_annotations` cannot be overridden. !!! note The parsed names come in the same order they appear in the expression, not in the execution order, apart from the magnet input names, which are added in the same order they appear in the list. The number of outputs is derived based on the number of commas outside of any bracket pair. If there is only one output, the output name is `out`. If more - `out1`, `out2`, etc. Any information can be overridden using `factory_kwargs`. Usage: ```pycon >>> WMA = vbt.IF( ... class_name='WMA', ... input_names=['close'], ... param_names=['window'], ... output_names=['wma'] ... 
).from_expr("wm_mean_nb(close, window)") >>> wma = WMA.run(price, window=[2, 3]) >>> wma.wma wma_window 2 3 a b a b 2020-01-01 NaN NaN NaN NaN 2020-01-02 1.666667 4.333333 NaN NaN 2020-01-03 2.666667 3.333333 2.333333 3.666667 2020-01-04 3.666667 2.333333 3.333333 2.666667 2020-01-05 4.666667 1.333333 4.333333 1.666667 ``` * The same can be achieved by calling the class method and providing prefixes to the variable names to indicate their type: ```pycon >>> expr = "WMA: @out_wma:wm_mean_nb((@in_high + @in_low) / 2, @p_window)" >>> WMA = vbt.IF.from_expr(expr) >>> wma = WMA.run(price + 1, price, window=[2, 3]) >>> wma.wma wma_window 2 3 a b a b 2020-01-01 NaN NaN NaN NaN 2020-01-02 2.166667 4.833333 NaN NaN 2020-01-03 3.166667 3.833333 2.833333 4.166667 2020-01-04 4.166667 2.833333 3.833333 3.166667 2020-01-05 5.166667 1.833333 4.833333 2.166667 ``` * Magnet names are recognized automatically: ```pycon >>> expr = "WMA: @out_wma:wm_mean_nb((high + low) / 2, @p_window)" ``` * Most settings of this method can be overriden from within the expression: ```pycon >>> expr = \"\"\" ... @settings({factory_kwargs={'class_name': 'WMA', 'param_names': ['window']}}) ... @out_wma:wm_mean_nb((high + low) / 2, window) ... 
\"\"\" ``` """ def _clean_expr(expr: str) -> str: # Clean the expression from redundant brackets and commas expr = inspect.cleandoc(expr).strip() if expr.endswith(","): expr = expr[:-1] if expr.startswith("(") and expr.endswith(")"): n_open_brackets = 0 remove_brackets = True for i, s in enumerate(expr): if s == "(": n_open_brackets += 1 elif s == ")": n_open_brackets -= 1 if n_open_brackets == 0 and i < len(expr) - 1: remove_brackets = False break if remove_brackets: expr = expr[1:-1] if expr.endswith(","): expr = expr[:-1] # again return expr if isinstance(cls_or_self, type): settings = dict( factory_kwargs=dict( class_name=None, input_names=[], in_output_names=[], param_names=[], output_names=[], ) ) # Parse the class name match = re.match(r"^\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\[([a-zA-Z_][a-zA-Z0-9_]*)\])?\s*:\s*", expr) if match: settings["factory_kwargs"]["class_name"] = match.group(1) if match.group(2): settings["factory_kwargs"]["short_name"] = match.group(2) expr = expr[len(match.group(0)) :] # Parse the settings dictionary if "@settings" in expr: remove_chars = set() for m in re.finditer("@settings", expr): n_open_brackets = 0 from_i = None to_i = None for i in range(m.start(), m.end()): remove_chars.add(i) for i in range(m.end(), len(expr)): remove_chars.add(i) s = expr[i] if s in "(": if n_open_brackets == 0: from_i = i + 1 n_open_brackets += 1 elif s in ")": n_open_brackets -= 1 if n_open_brackets == 0: to_i = i break if n_open_brackets != 0: raise ValueError("Couldn't parse the settings: mismatching brackets") settings = merge_dicts(settings, eval(_clean_expr(expr[from_i:to_i]))) expr = "".join([expr[i] for i in range(len(expr)) if i not in remove_chars]) expr = _clean_expr(expr) # Merge info parsed_factory_kwargs = settings.pop("factory_kwargs") magnet_inputs = settings.pop("magnet_inputs", magnet_inputs) magnet_in_outputs = settings.pop("magnet_in_outputs", magnet_in_outputs) magnet_params = settings.pop("magnet_params", magnet_params) func_mapping = 
merge_dicts(expr_func_config, settings.pop("func_mapping", None), func_mapping) res_func_mapping = merge_dicts( expr_res_func_config, settings.pop("res_func_mapping", None), res_func_mapping, ) use_pd_eval = settings.pop("use_pd_eval", use_pd_eval) pd_eval_kwargs = merge_dicts(settings.pop("pd_eval_kwargs", None), pd_eval_kwargs) # Resolve defaults if use_pd_eval is None: use_pd_eval = False if magnet_inputs is None: magnet_inputs = ["open", "high", "low", "close", "volume"] if magnet_in_outputs is None: magnet_in_outputs = [] if magnet_params is None: magnet_params = [] found_magnet_inputs = [] found_magnet_in_outputs = [] found_magnet_params = [] found_defaults = {} remove_defaults = set() # Parse annotated variables if parse_annotations: # Parse input, in-output, parameter, and TA-Lib function names for var_name in re.findall(r"@[a-z]+_[a-zA-Z_][a-zA-Z0-9_]*", expr): var_name = var_name.replace("@", "") if var_name.startswith("in_"): var_name = var_name[3:] if var_name in magnet_inputs: if var_name not in found_magnet_inputs: found_magnet_inputs.append(var_name) else: if var_name not in parsed_factory_kwargs["input_names"]: parsed_factory_kwargs["input_names"].append(var_name) elif var_name.startswith("inout_"): var_name = var_name[6:] if var_name in magnet_in_outputs: if var_name not in found_magnet_in_outputs: found_magnet_in_outputs.append(var_name) else: if var_name not in parsed_factory_kwargs["in_output_names"]: parsed_factory_kwargs["in_output_names"].append(var_name) elif var_name.startswith("p_"): var_name = var_name[2:] if var_name in magnet_params: if var_name not in found_magnet_params: found_magnet_params.append(var_name) else: if var_name not in parsed_factory_kwargs["param_names"]: parsed_factory_kwargs["param_names"].append(var_name) elif var_name.startswith("res_"): ind_name = var_name[4:] if ind_name.startswith("talib_"): ind_name = ind_name[6:] I = cls_or_self.from_talib(ind_name) else: I = kwargs[ind_name] if not issubclass(I, IndicatorBase): 
raise TypeError(f"Indicator class '{ind_name}' must subclass IndicatorBase") def _ind_func(context: tp.Kwargs, _I: IndicatorBase = I) -> tp.Any: _args = () _kwargs = {} signature = inspect.signature(_I.run) for p in signature.parameters.values(): if p.name in _I.input_names: if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD): _args += (context[p.name],) else: _kwargs[p.name] = context[p.name] else: ind_p_name = _I.short_name + "_" + p.name if ind_p_name in context: if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD): _args += (context[ind_p_name],) elif p.kind == p.VAR_POSITIONAL: _args += context[ind_p_name] elif p.kind == p.VAR_KEYWORD: for k, v in context[ind_p_name].items(): _kwargs[k] = v else: _kwargs[p.name] = context[ind_p_name] return_raw = _kwargs.pop("return_raw", True) ind = _I.run(*_args, return_raw=return_raw, **_kwargs) if return_raw: raw_outputs = ind[0] if len(raw_outputs) == 1: return raw_outputs[0] return raw_outputs return ind res_func_mapping["__" + var_name] = dict( func=_ind_func, magnet_inputs=I.input_names, magnet_in_outputs=[I.short_name + "_" + name for name in I.in_output_names], magnet_params=[I.short_name + "_" + name for name in I.param_names], ) run_kwargs = get_func_kwargs(I.run) def _add_defaults(names, prefix=None): for k in names: if prefix is None: k_prefixed = k else: k_prefixed = prefix + "_" + k if k in run_kwargs: if k_prefixed in found_defaults: if not checks.is_deep_equal(found_defaults[k_prefixed], run_kwargs[k]): remove_defaults.add(k_prefixed) else: found_defaults[k_prefixed] = run_kwargs[k] _add_defaults(I.input_names) _add_defaults(I.in_output_names, I.short_name) _add_defaults(I.param_names, I.short_name) expr = expr.replace("@in_", "__in_") expr = expr.replace("@inout_", "__inout_") expr = expr.replace("@p_", "__p_") expr = expr.replace("@talib_", "__talib_") expr = expr.replace("@res_", "__res_") # Parse output names to_replace = [] for var_name in re.findall(r"@out_[a-zA-Z_][a-zA-Z0-9_]*\s*:\s*", 
expr): to_replace.append(var_name) var_name = var_name.split(":")[0].strip()[5:] if var_name not in parsed_factory_kwargs["output_names"]: parsed_factory_kwargs["output_names"].append(var_name) for s in to_replace: expr = expr.replace(s, "") for var_name in re.findall(r"@out_[a-zA-Z_][a-zA-Z0-9_]*", expr): var_name = var_name.replace("@", "") if var_name.startswith("out_"): var_name = var_name[4:] if var_name not in parsed_factory_kwargs["output_names"]: parsed_factory_kwargs["output_names"].append(var_name) expr = expr.replace("@out_", "__out_") if len(parsed_factory_kwargs["output_names"]) == 0: lines = expr.split("\n") if len(lines) > 1: last_line = _clean_expr(lines[-1]) valid_output_names = [] found_not_valid = False for i, out in enumerate(last_line.split(",")): out = out.strip() if not out.startswith("__") and out.isidentifier(): valid_output_names.append(out) else: found_not_valid = True break if not found_not_valid: parsed_factory_kwargs["output_names"] = valid_output_names # Parse magnet names var_names = get_expr_var_names(expr) def _find_magnets(magnet_type, magnet_names, magnet_lst, found_magnet_lst): for var_name in var_names: if var_name in magnet_lst: if var_name not in found_magnet_lst: found_magnet_lst.append(var_name) if var_name in func_mapping: for magnet_name in func_mapping[var_name].get(magnet_type, []): if magnet_name not in found_magnet_lst: found_magnet_lst.append(magnet_name) if var_name in res_func_mapping: for magnet_name in res_func_mapping[var_name].get(magnet_type, []): if magnet_name not in found_magnet_lst: found_magnet_lst.append(magnet_name) for magnet_name in magnet_lst: if magnet_name in found_magnet_lst and magnet_name not in magnet_names: magnet_names.append(magnet_name) for magnet_name in found_magnet_lst: if magnet_name not in magnet_names and magnet_name not in magnet_names: magnet_names.append(magnet_name) _find_magnets("magnet_inputs", parsed_factory_kwargs["input_names"], magnet_inputs, found_magnet_inputs) 
def apply_func(
    input_tuple: tp.Tuple[tp.AnyArray, ...],
    in_output_tuple: tp.Tuple[tp.SeriesFrame, ...],
    param_tuple: tp.Tuple[tp.ParamValue, ...],
    **_kwargs,
) -> tp.MaybeTuple[tp.Array2d]:
    """Evaluate the expression against the packed inputs, in-outputs, and parameters.

    Each packed value is registered under its name from `input_names`,
    `in_output_names`, or `param_names` and merged with `_kwargs`. Every name in
    `var_names` is then resolved (explicit `__in_`/`__inout_`/`__p_`/`__talib_`
    prefixes first, then the function mappings, the merged context, NumPy,
    `generic_nb`, and finally `vbt`), and the expression is evaluated with the
    resolved names as its context.

    Args:
        input_tuple: Input arrays, positionally aligned with `input_names`.
        in_output_tuple: In-place outputs, positionally aligned with `in_output_names`.
        param_tuple: Parameter values, positionally aligned with `param_names`.
        **_kwargs: Extra context; must contain `wrapper` if the expression uses
            a `__talib_`-prefixed name.

    Returns:
        Whatever `pd.eval` (if `use_pd_eval`) or `evaluate` returns for the expression.
    """
    import vectorbtpro as vbt

    input_context = dict(np=np, pd=pd, vbt=vbt)
    # `input_arr` rather than `input` so the builtin is not shadowed
    for i, input_arr in enumerate(input_tuple):
        input_context[input_names[i]] = input_arr
    for i, in_output in enumerate(in_output_tuple):
        input_context[in_output_names[i]] = in_output
    for i, param in enumerate(param_tuple):
        input_context[param_names[i]] = param
    merged_context = merge_dicts(input_context, _kwargs)
    context = {}

    # Resolve each variable in the expression
    for var_name in var_names:
        if var_name in context:
            continue
        if var_name.startswith("__in_"):
            var = merged_context[var_name[5:]]
        elif var_name.startswith("__inout_"):
            var = merged_context[var_name[8:]]
        elif var_name.startswith("__p_"):
            var = merged_context[var_name[4:]]
        elif var_name.startswith("__talib_"):
            from vectorbtpro.indicators.talib_ import talib_func

            talib_func_name = var_name[8:].upper()
            _talib_func = talib_func(talib_func_name)
            var = functools.partial(_talib_func, wrapper=_kwargs["wrapper"])
        elif var_name in res_func_mapping:
            var = res_func_mapping[var_name]["func"]
        elif var_name in func_mapping:
            var = func_mapping[var_name]["func"]
        elif var_name in merged_context:
            var = merged_context[var_name]
        elif hasattr(np, var_name):
            var = getattr(np, var_name)
        elif hasattr(generic_nb, var_name):
            var = getattr(generic_nb, var_name)
        elif hasattr(generic_nb, var_name + "_nb"):
            var = getattr(generic_nb, var_name + "_nb")
        elif hasattr(vbt, var_name):
            var = getattr(vbt, var_name)
        else:
            # Unknown names are left for the evaluator to deal with
            continue
        try:
            if callable(var) and "context" in get_func_arg_names(var):
                var = functools.partial(var, context=merged_context)
        except Exception:
            # Best effort: if argument introspection fails, use the callable as is.
            # Only `Exception` is caught so that KeyboardInterrupt/SystemExit propagate.
            pass
        if var_name in res_func_mapping:
            # Entries of the "res" mapping are called right away; their result is the variable
            var = var()
        context[var_name] = var

    # Evaluate the expression using resolved variables as a context
    if use_pd_eval:
        return pd.eval(expr, local_dict=context, **resolve_dict(pd_eval_kwargs))
    return evaluate(expr, context=context)
tp.Type[IndicatorBase]: """Build an indicator class from one of the WorldQuant's 101 alpha expressions. See `vectorbtpro.indicators.expr.wqa101_expr_config`. !!! note Some expressions that utilize cross-sectional operations require columns to be a multi-index with a level `sector`, `subindustry`, or `industry`. Usage: ```pycon >>> data = vbt.YFData.pull(['BTC-USD', 'ETH-USD']) >>> WQA1 = vbt.IF.from_wqa101(1) >>> wqa1 = WQA1.run(data.get('Close')) >>> wqa1.out symbol BTC-USD ETH-USD Date 2014-09-17 00:00:00+00:00 0.25 0.25 2014-09-18 00:00:00+00:00 0.25 0.25 2014-09-19 00:00:00+00:00 0.25 0.25 2014-09-20 00:00:00+00:00 0.25 0.25 2014-09-21 00:00:00+00:00 0.25 0.25 ... ... ... 2022-01-21 00:00:00+00:00 0.00 0.50 2022-01-22 00:00:00+00:00 0.00 0.50 2022-01-23 00:00:00+00:00 0.25 0.25 2022-01-24 00:00:00+00:00 0.50 0.00 2022-01-25 00:00:00+00:00 0.50 0.00 [2688 rows x 2 columns] ``` * To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`: ```pycon >>> vbt.phelp(WQA1.run) WQA1.run( close, short_name='wqa1', hide_params=None, hide_default=True, **kwargs ): Run `WQA1` indicator. * Inputs: `close` * Outputs: `out` Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all. Set `hide_default` to False to show the column levels of the parameters with a default value. Other keyword arguments are passed to `WQA1.run_pipeline`. 
@classmethod
def list_wqa101_indicators(cls) -> tp.List[str]:
    """Return the indices of all WorldQuant's 101 alphas as strings ("1" through "101")."""
    first_idx, last_idx = 1, 101
    return list(map(str, range(first_idx, last_idx + 1)))
**kwargs) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Numba-compiled functions for custom indicators. Provides an arsenal of Numba-compiled functions that are used by indicator classes. These only accept NumPy arrays and other Numba-compatible types.""" import numpy as np from numba import prange from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base import chunking as base_ch from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb from vectorbtpro.generic import nb as generic_nb, enums as generic_enums from vectorbtpro.indicators.enums import Pivot, SuperTrendAIS, SuperTrendAOS, HurstMethod from vectorbtpro.registries.ch_registry import register_chunkable from vectorbtpro.registries.jit_registry import register_jitted from vectorbtpro.utils import chunking as ch __all__ = [] # ############# MA ############# # @register_jitted(cache=True) def ma_1d_nb( close: tp.Array1d, window: int = 14, wtype: int = generic_enums.WType.Simple, minp: tp.Optional[int] = None, adjust: bool = False, ) -> tp.Array1d: """Moving average. 
@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ma_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array2d:
    """Moving average of each column of `close`, computed with `ma_1d_nb`.

    `window` and `wtype` are converted to 1-dim arrays, and the value for each
    column is picked with `flex_select_1d_nb`; `minp` and `adjust` are shared
    by all columns."""
    windows = to_1d_array_nb(np.asarray(window))
    wtypes = to_1d_array_nb(np.asarray(wtype))
    n_cols = close.shape[1]
    out = np.empty(close.shape, dtype=float_)
    for col in prange(n_cols):
        col_window = flex_select_1d_nb(windows, col)
        col_wtype = flex_select_1d_nb(wtypes, col)
        out[:, col] = ma_1d_nb(
            close=close[:, col],
            window=col_window,
            wtype=col_wtype,
            minp=minp,
            adjust=adjust,
        )
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
        ddof=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def msd_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
    ddof: int = 0,
) -> tp.Array2d:
    """Moving standard deviation of each column of `close`, computed with `msd_1d_nb`.

    `window` and `wtype` are converted to 1-dim arrays, and the value for each
    column is picked with `flex_select_1d_nb`; `minp`, `adjust`, and `ddof` are
    shared by all columns."""
    windows = to_1d_array_nb(np.asarray(window))
    wtypes = to_1d_array_nb(np.asarray(wtype))
    n_cols = close.shape[1]
    out = np.empty(close.shape, dtype=float_)
    for col in prange(n_cols):
        col_window = flex_select_1d_nb(windows, col)
        col_wtype = flex_select_1d_nb(wtypes, col)
        out[:, col] = msd_1d_nb(
            close=close[:, col],
            window=col_window,
            wtype=col_wtype,
            minp=minp,
            adjust=adjust,
            ddof=ddof,
        )
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        alpha=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
        ddof=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bbands_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    alpha: tp.FlexArray1dLike = 2.0,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
    ddof: int = 0,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d]:
    """Bollinger Bands of each column of `close`, computed with `bbands_1d_nb`.

    Returns the upper, middle, and lower bands, each shaped like `close`.
    `window`, `wtype`, and `alpha` are converted to 1-dim arrays, and the value
    for each column is picked with `flex_select_1d_nb`."""
    windows = to_1d_array_nb(np.asarray(window))
    wtypes = to_1d_array_nb(np.asarray(wtype))
    alphas = to_1d_array_nb(np.asarray(alpha))
    shape = close.shape
    upper = np.empty(shape, dtype=float_)
    middle = np.empty(shape, dtype=float_)
    lower = np.empty(shape, dtype=float_)
    for col in prange(shape[1]):
        bands = bbands_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(windows, col),
            wtype=flex_select_1d_nb(wtypes, col),
            alpha=flex_select_1d_nb(alphas, col),
            minp=minp,
            adjust=adjust,
            ddof=ddof,
        )
        upper[:, col] = bands[0]
        middle[:, col] = bands[1]
        lower[:, col] = bands[2]
    return upper, middle, lower
@register_jitted(cache=True)
def avg_gain_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Wilder,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Average gain: moving average (via `ma_1d_nb`) of the upward changes in `close`.

    The first element has no previous value and is NaN; a negative change
    counts as a gain of zero."""
    n = close.shape[0]
    gains = np.empty(close.shape, dtype=float_)
    if n > 0:
        gains[0] = np.nan
    for i in range(1, n):
        diff = close[i] - close[i - 1]
        # A NaN difference fails the comparison and is therefore kept as NaN
        gains[i] = 0.0 if diff < 0 else diff
    return ma_1d_nb(gains, window=window, wtype=wtype, minp=minp, adjust=adjust)
@register_jitted(cache=True)
def avg_loss_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Wilder,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Average loss: moving average (via `ma_1d_nb`) of the downward changes in `close`.

    The first element has no previous value and is NaN; losses are stored as
    positive magnitudes, and any change that is not negative counts as zero."""
    n = close.shape[0]
    losses = np.empty(close.shape, dtype=float_)
    if n > 0:
        losses[0] = np.nan
    for i in range(1, n):
        diff = close[i] - close[i - 1]
        # A NaN difference fails the comparison and is therefore recorded as 0.0
        losses[i] = abs(diff) if diff < 0 else 0.0
    return ma_1d_nb(losses, window=window, wtype=wtype, minp=minp, adjust=adjust)
@register_jitted(cache=True)
def rsi_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Wilder,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Relative Strength Index (RSI).

    Computed as `100 * avg_gain / (avg_gain + avg_loss)`, where both averages
    come from `avg_gain_1d_nb` and `avg_loss_1d_nb` with the same settings.

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    gain = avg_gain_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust)
    loss = avg_loss_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust)
    scaled_gain = 100 * gain
    total_move = gain + loss
    return scaled_gain / total_move
@register_jitted(cache=True)
def stoch_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    close: tp.Array1d,
    fast_k_window: int = 14,
    slow_k_window: int = 3,
    slow_d_window: int = 3,
    wtype: int = generic_enums.WType.Simple,
    slow_k_wtype: tp.Optional[int] = None,
    slow_d_wtype: tp.Optional[int] = None,
    minp: tp.Optional[int] = None,
    fast_k_minp: tp.Optional[int] = None,
    slow_k_minp: tp.Optional[int] = None,
    slow_d_minp: tp.Optional[int] = None,
    adjust: bool = False,
    slow_k_adjust: tp.Optional[bool] = None,
    slow_d_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
    """Stochastic Oscillator.

    Returns the fast %K (from `stoch_k_1d_nb`), the slow %K (moving average of
    the fast %K), and the slow %D (moving average of the slow %K).

    Every stage-specific argument (`slow_k_*`, `slow_d_*`, `fast_k_minp`) that
    is None falls back to its general counterpart (`wtype`, `minp`, `adjust`).

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    # Fast %K stage
    if fast_k_minp is None:
        fk_minp = minp
    else:
        fk_minp = fast_k_minp

    # Slow %K stage
    if slow_k_wtype is None:
        sk_wtype = wtype
    else:
        sk_wtype = slow_k_wtype
    if slow_k_minp is None:
        sk_minp = minp
    else:
        sk_minp = slow_k_minp
    if slow_k_adjust is None:
        sk_adjust = adjust
    else:
        sk_adjust = slow_k_adjust

    # Slow %D stage
    if slow_d_wtype is None:
        sd_wtype = wtype
    else:
        sd_wtype = slow_d_wtype
    if slow_d_minp is None:
        sd_minp = minp
    else:
        sd_minp = slow_d_minp
    if slow_d_adjust is None:
        sd_adjust = adjust
    else:
        sd_adjust = slow_d_adjust

    fast_k = stoch_k_1d_nb(high, low, close, window=fast_k_window, minp=fk_minp)
    slow_k = ma_1d_nb(fast_k, window=slow_k_window, wtype=sk_wtype, minp=sk_minp, adjust=sk_adjust)
    slow_d = ma_1d_nb(slow_k, window=slow_d_window, wtype=sd_wtype, minp=sd_minp, adjust=sd_adjust)
    return fast_k, slow_k, slow_d
@register_jitted(cache=True)
def macd_1d_nb(
    close: tp.Array1d,
    fast_window: int = 12,
    slow_window: int = 26,
    signal_window: int = 9,
    wtype: int = generic_enums.WType.Exp,
    macd_wtype: tp.Optional[int] = None,
    signal_wtype: tp.Optional[int] = None,
    minp: tp.Optional[int] = None,
    macd_minp: tp.Optional[int] = None,
    signal_minp: tp.Optional[int] = None,
    adjust: bool = False,
    macd_adjust: tp.Optional[bool] = None,
    signal_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    """MACD.

    Returns the MACD line (fast moving average minus slow moving average of
    `close`) and the signal line (moving average of the MACD line).

    Every `macd_*` and `signal_*` argument that is None falls back to its
    general counterpart (`wtype`, `minp`, `adjust`).

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    # Settings of the two moving averages that form the MACD line
    if macd_wtype is None:
        line_wtype = wtype
    else:
        line_wtype = macd_wtype
    if macd_minp is None:
        line_minp = minp
    else:
        line_minp = macd_minp
    if macd_adjust is None:
        line_adjust = adjust
    else:
        line_adjust = macd_adjust

    # Settings of the signal line
    if signal_wtype is None:
        sig_wtype = wtype
    else:
        sig_wtype = signal_wtype
    if signal_minp is None:
        sig_minp = minp
    else:
        sig_minp = signal_minp
    if signal_adjust is None:
        sig_adjust = adjust
    else:
        sig_adjust = signal_adjust

    fast_ma = ma_1d_nb(close, window=fast_window, wtype=line_wtype, minp=line_minp, adjust=line_adjust)
    slow_ma = ma_1d_nb(close, window=slow_window, wtype=line_wtype, minp=line_minp, adjust=line_adjust)
    macd = fast_ma - slow_ma
    signal = ma_1d_nb(macd, window=signal_window, wtype=sig_wtype, minp=sig_minp, adjust=sig_adjust)
    return macd, signal
@register_chunkable(
    size=ch.ArraySizer(arg_query="macd", axis=1),
    arg_take_spec=dict(
        macd=ch.ArraySlicer(axis=1),
        signal=ch.ArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def macd_hist_nb(macd: tp.Array2d, signal: tp.Array2d) -> tp.Array2d:
    """MACD histogram of each column, computed with `macd_hist_1d_nb`.

    The result has the shape of `macd`."""
    n_cols = macd.shape[1]
    hist = np.empty(macd.shape, dtype=float_)
    for col in prange(n_cols):
        macd_col = macd[:, col]
        signal_col = signal[:, col]
        hist[:, col] = macd_hist_1d_nb(macd_col, signal_col)
    return hist
@register_jitted(cache=True)
def tr_1d_nb(high: tp.Array1d, low: tp.Array1d, close: tp.Array1d) -> tp.Array1d:
    """True Range (TR) of every element, computed with `iter_tr_nb`.

    The first element has no previous close, so NaN is passed in its place."""
    out = np.empty(close.shape, dtype=float_)
    prev_close = np.nan
    for i in range(close.shape[0]):
        out[i] = iter_tr_nb(high[i], low[i], prev_close)
        # Carry the current close over to the next iteration
        prev_close = close[i]
    return out
@register_jitted(cache=True)
def adx_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Wilder,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
    """Average Directional Movement Index (ADX).

    Returns +DI, -DI, DX, and ADX. The directional movements are smoothed
    with `ma_1d_nb` and divided by the ATR from `atr_1d_nb`; ADX is the moving
    average of DX. All smoothing steps share `window`, `wtype`, `minp`, and `adjust`.

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    n = close.shape[0]
    plus_dm = np.empty(close.shape, dtype=float_)
    minus_dm = np.empty(close.shape, dtype=float_)
    if n > 0:
        # No previous bar to compare against: both movements are zero
        plus_dm[0] = 0.0
        minus_dm[0] = 0.0
    for i in range(1, n):
        up_move = high[i] - high[i - 1]
        down_move = low[i - 1] - low[i]
        # Only the larger of the two moves counts, and only if it is positive
        plus_dm[i] = up_move if (up_move > down_move and up_move > 0) else 0.0
        minus_dm[i] = down_move if (down_move > up_move and down_move > 0) else 0.0

    atr = atr_1d_nb(
        high,
        low,
        close,
        window=window,
        wtype=wtype,
        minp=minp,
        adjust=adjust,
    )[1]
    plus_dm_ma = ma_1d_nb(plus_dm, window, wtype=wtype, minp=minp, adjust=adjust)
    minus_dm_ma = ma_1d_nb(minus_dm, window, wtype=wtype, minp=minp, adjust=adjust)
    plus_di = 100 * plus_dm_ma / atr
    minus_di = 100 * minus_dm_ma / atr
    di_spread = np.abs(plus_di - minus_di)
    di_total = plus_di + minus_di
    dx = 100 * di_spread / di_total
    adx = ma_1d_nb(dx, window, wtype=wtype, minp=minp, adjust=adjust)
    return plus_di, minus_di, dx, adx
@register_jitted(cache=True)
def obv_1d_nb(close: tp.Array1d, volume: tp.Array1d) -> tp.Array1d:
    """On-Balance Volume (OBV).

    Running total of signed volume: the volume is subtracted when the close is
    below the previous close and added otherwise (including the first element).
    NaN volume leaves the running total unchanged."""
    out = np.empty(close.shape, dtype=float_)
    running = 0.0
    prev_close = np.nan
    for i in range(close.shape[0]):
        # The comparison is False against the initial NaN, so the first volume is added
        signed_volume = -volume[i] if close[i] < prev_close else volume[i]
        if not np.isnan(signed_volume):
            running += signed_volume
        out[i] = running
        prev_close = close[i]
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="x", axis=1),
    arg_take_spec=dict(
        x=ch.ArraySlicer(axis=1),
        slope=ch.ArraySlicer(axis=1),
        intercept=ch.ArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ols_pred_nb(x: tp.Array2d, slope: tp.Array2d, intercept: tp.Array2d) -> tp.Array2d:
    """2-dim version of `ols_pred_1d_nb`.

    Applies `ols_pred_1d_nb` to each column of `x`, `slope`, and `intercept`
    and returns an array of the same shape as `x`."""
    out = np.empty(x.shape, dtype=float_)
    n_cols = x.shape[1]
    for c in prange(n_cols):
        x_col = x[:, c]
        slope_col = slope[:, c]
        intercept_col = intercept[:, c]
        out[:, c] = ols_pred_1d_nb(x_col, slope_col, intercept_col)
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        close=ch.ArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def typical_price_nb(high: tp.Array2d, low: tp.Array2d, close: tp.Array2d) -> tp.Array2d:
    """2-dim version of `typical_price_1d_nb`.

    Applies `typical_price_1d_nb` to each column of `high`, `low`, and `close`
    and returns an array of the same shape as `close`."""
    out = np.empty(close.shape, dtype=float_)
    n_cols = close.shape[1]
    for c in prange(n_cols):
        high_col = high[:, c]
        low_col = low[:, c]
        close_col = close[:, c]
        out[:, c] = typical_price_1d_nb(high_col, low_col, close_col)
    return out
@register_jitted(cache=True)
def pivot_info_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    up_th: tp.FlexArray1dLike,
    down_th: tp.FlexArray1dLike,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
    """Pivot information.

    Walks `high` and `low` row by row and tracks two pivots: the most recently
    confirmed pivot and the latest (not yet confirmed) pivot candidate.

    `up_th` and `down_th` are flexible per-row thresholds. Only their absolute values
    are used, as the multipliers `1 + abs(up_th)` and `1 - abs(down_th)`.

    Returns four integer arrays of the same shape as `high`:

    * `conf_pivot`: type of the confirmed pivot (0 while there is none)
    * `conf_idx`: row of the confirmed pivot (-1 while there is none)
    * `last_pivot`: type of the latest pivot candidate (0 while there is none)
    * `last_idx`: row of the latest pivot candidate (-1 while there is none)

    Rows where `high` or `low` is NaN carry the current state forward unchanged."""
    up_th_ = to_1d_array_nb(np.asarray(up_th))
    down_th_ = to_1d_array_nb(np.asarray(down_th))

    conf_pivot = np.empty(high.shape, dtype=int_)
    conf_idx = np.empty(high.shape, dtype=int_)
    last_pivot = np.empty(high.shape, dtype=int_)
    last_idx = np.empty(high.shape, dtype=int_)

    # Running state that is written to the output arrays at the end of each row.
    # `_conf_value` is maintained alongside but is not read anywhere in this function.
    _conf_pivot = 0
    _conf_idx = -1
    _conf_value = np.nan
    _last_pivot = 0
    _last_idx = -1
    _last_value = np.nan
    first_valid_idx = -1
    for i in range(high.shape[0]):
        if not np.isnan(high[i]) and not np.isnan(low[i]):
            if first_valid_idx == -1:
                # The first valid row additionally requires both thresholds to be non-NaN
                _up_th = 1 + abs(flex_select_1d_nb(up_th_, i))
                _down_th = 1 - abs(flex_select_1d_nb(down_th_, i))
                if np.isnan(_up_th) or np.isnan(_down_th):
                    conf_pivot[i] = _conf_pivot
                    conf_idx[i] = _conf_idx
                    last_pivot[i] = _last_pivot
                    last_idx[i] = _last_idx
                    continue
                first_valid_idx = i
            if _last_idx == -1:
                # No pivot yet: measure the move against the first valid row,
                # using the thresholds selected at that row
                _up_th = 1 + abs(flex_select_1d_nb(up_th_, first_valid_idx))
                _down_th = 1 - abs(flex_select_1d_nb(down_th_, first_valid_idx))
                if not np.isnan(_up_th) and high[i] >= low[first_valid_idx] * _up_th:
                    if not np.isnan(_down_th) and low[i] <= high[first_valid_idx] * _down_th:
                        # Both thresholds are crossed on the same row: direction is ambiguous
                        pass  # wait
                    else:
                        # Up move: the first valid row becomes a confirmed valley,
                        # the current row becomes the peak candidate
                        _conf_pivot = Pivot.Valley
                        _conf_idx = first_valid_idx
                        _conf_value = low[first_valid_idx]
                        _last_pivot = Pivot.Peak
                        _last_idx = i
                        _last_value = high[i]
                if not np.isnan(_down_th) and low[i] <= high[first_valid_idx] * _down_th:
                    if not np.isnan(_up_th) and high[i] >= low[first_valid_idx] * _up_th:
                        # Both thresholds are crossed on the same row: direction is ambiguous
                        pass  # wait
                    else:
                        # Down move: the first valid row becomes a confirmed peak,
                        # the current row becomes the valley candidate
                        _conf_pivot = Pivot.Peak
                        _conf_idx = first_valid_idx
                        _conf_value = high[first_valid_idx]
                        _last_pivot = Pivot.Valley
                        _last_idx = i
                        _last_value = low[i]
            else:
                # A candidate exists: thresholds are selected at the candidate's row
                _up_th = 1 + abs(flex_select_1d_nb(up_th_, _last_idx))
                _down_th = 1 - abs(flex_select_1d_nb(down_th_, _last_idx))
                if _last_pivot == Pivot.Valley:
                    if not np.isnan(_last_value) and not np.isnan(_up_th) and high[i] >= _last_value * _up_th:
                        # Sufficient rise from the valley candidate: confirm it
                        # and start a peak candidate at the current row
                        _conf_pivot = _last_pivot
                        _conf_idx = _last_idx
                        _conf_value = _last_value
                        _last_pivot = Pivot.Peak
                        _last_idx = i
                        _last_value = high[i]
                    elif np.isnan(_last_value) or low[i] < _last_value:
                        # Lower low: move the valley candidate to the current row
                        _last_idx = i
                        _last_value = low[i]
                elif _last_pivot == Pivot.Peak:
                    if not np.isnan(_last_value) and not np.isnan(_down_th) and low[i] <= _last_value * _down_th:
                        # Sufficient fall from the peak candidate: confirm it
                        # and start a valley candidate at the current row
                        _conf_pivot = _last_pivot
                        _conf_idx = _last_idx
                        _conf_value = _last_value
                        _last_pivot = Pivot.Valley
                        _last_idx = i
                        _last_value = low[i]
                    elif np.isnan(_last_value) or high[i] > _last_value:
                        # Higher high: move the peak candidate to the current row
                        _last_idx = i
                        _last_value = high[i]

        # Written for every row, including rows skipped because of NaN
        conf_pivot[i] = _conf_pivot
        conf_idx[i] = _conf_idx
        last_pivot[i] = _last_pivot
        last_idx[i] = _last_idx

    return conf_pivot, conf_idx, last_pivot, last_idx
@register_jitted(cache=True)
def pivots_1d_nb(conf_pivot: tp.Array1d, conf_idx: tp.Array1d, last_pivot: tp.Array1d) -> tp.Array1d:
    """Pivots.

    Builds an integer array that is zero everywhere except at the rows of confirmed pivots
    (taken from `conf_idx` and `conf_pivot`, all rows but the last one) and at the last row,
    which is set to the last value of `last_pivot`.

    Negative entries in `conf_idx` mean "no confirmed pivot yet" and are skipped.
    An empty input yields an empty output.

    !!! warning
        To be used in plotting. Do not use it as an indicator!"""
    n = conf_pivot.shape[0]
    pivots = np.zeros(conf_pivot.shape, dtype=int_)
    if n == 0:
        # Nothing to mark; indexing the last element would be out of bounds
        return pivots
    for i in range(n - 1):
        idx = conf_idx[i]
        if idx >= 0:
            # Skip the -1 sentinel instead of letting it wrap around to the last element
            pivots[idx] = conf_pivot[i]
    pivots[n - 1] = last_pivot[n - 1]
    return pivots
@register_jitted(cache=True)
def supertrend_acc_nb(in_state: SuperTrendAIS) -> SuperTrendAOS:
    """Accumulator of `supertrend_nb`.

    Takes a state of type `vectorbtpro.indicators.enums.SuperTrendAIS` and returns
    a state of type `vectorbtpro.indicators.enums.SuperTrendAOS`.

    Performs one step: updates the EWM mean of the true range (ATR), derives the basic bands
    from it, and, for every step but the first one, finalizes the bands, trend, and direction
    using the previous step's bands and direction."""
    # ATR step: EWM mean of the true range with alpha derived from the period
    atr_in_state = generic_enums.EWMMeanAIS(
        i=in_state.i,
        value=iter_tr_nb(in_state.high, in_state.low, in_state.prev_close),
        old_wt=in_state.old_wt,
        weighted_avg=in_state.weighted_avg,
        nobs=in_state.nobs,
        alpha=generic_nb.alpha_from_wilder_nb(in_state.period),
        minp=in_state.period,
        adjust=False,
    )
    atr_out_state = generic_nb.ewm_mean_acc_nb(atr_in_state)

    upper, lower = iter_basic_bands_nb(
        in_state.high,
        in_state.low,
        atr_out_state.value,
        in_state.multiplier,
    )
    if in_state.i != 0:
        upper, lower, trend, direction, long, short = final_basic_bands_nb(
            in_state.close,
            upper,
            lower,
            in_state.prev_upper,
            in_state.prev_lower,
            in_state.prev_direction,
        )
    else:
        # First step: there are no previous bands to compare against
        trend = np.nan
        direction = 1
        long = np.nan
        short = np.nan

    return SuperTrendAOS(
        nobs=atr_out_state.nobs,
        weighted_avg=atr_out_state.weighted_avg,
        old_wt=atr_out_state.old_wt,
        upper=upper,
        lower=lower,
        trend=trend,
        direction=direction,
        long=long,
        short=short,
    )
@register_jitted(cache=True)
def signal_detection_1d_nb(
    close: tp.Array1d,
    lag: int = 14,
    factor: tp.FlexArray1dLike = 1.0,
    influence: tp.FlexArray1dLike = 1.0,
    up_factor: tp.Optional[tp.FlexArray1dLike] = None,
    down_factor: tp.Optional[tp.FlexArray1dLike] = None,
    mean_influence: tp.Optional[tp.FlexArray1dLike] = None,
    std_influence: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
    """Signal detection.

    Maintains a rolling mean and a rolling standard deviation over the last `lag` values
    of a filtered copy of `close`. At each row, the signal is 1 if `close` is at least
    `up_factor` standard deviations above the previous mean, -1 if it is at least
    `down_factor` standard deviations below it, and 0 otherwise. When a signal occurs,
    the value entering the filters is a blend of `close` and the previous filtered value,
    weighted by `mean_influence` and `std_influence` respectively.

    `up_factor` and `down_factor` default to `factor`; `mean_influence` and `std_influence`
    default to `influence`. All of them are flexible per-row arguments whose absolute
    value is used.

    Returns the arrays `signal`, `upper_band`, and `lower_band`. Rows before `lag` hold
    0 in `signal` and NaN in the bands.

    Raises a `ValueError` if `lag` is negative, zero, or greater than the length of `close`."""
    # Validate before allocating anything. A negative lag would otherwise be silently
    # interpreted as negative (wrap-around) indices below.
    if lag < 0:
        raise ValueError("Lag cannot be negative")
    if lag == 0:
        raise ValueError("Lag cannot be zero")
    if lag - 1 >= close.shape[0]:
        raise ValueError("Lag must be smaller than close")

    factor_ = to_1d_array_nb(np.asarray(factor))
    influence_ = to_1d_array_nb(np.asarray(influence))
    if up_factor is not None:
        up_factor_ = to_1d_array_nb(np.asarray(up_factor))
    else:
        up_factor_ = factor_
    if down_factor is not None:
        down_factor_ = to_1d_array_nb(np.asarray(down_factor))
    else:
        down_factor_ = factor_
    if mean_influence is not None:
        mean_influence_ = to_1d_array_nb(np.asarray(mean_influence))
    else:
        mean_influence_ = influence_
    if std_influence is not None:
        std_influence_ = to_1d_array_nb(np.asarray(std_influence))
    else:
        std_influence_ = influence_

    signal = np.full(close.shape, 0, dtype=int_)
    # Filtered copies of close feeding the rolling mean and std respectively
    close_mean_filter = close.astype(float_)
    close_std_filter = close.astype(float_)
    mean_filter = np.full(close.shape, np.nan, dtype=float_)
    std_filter = np.full(close.shape, np.nan, dtype=float_)
    upper_band = np.full(close.shape, np.nan, dtype=float_)
    lower_band = np.full(close.shape, np.nan, dtype=float_)

    # Seed the filters with the statistics of the first `lag` values
    mean_filter[lag - 1] = np.nanmean(close[:lag])
    std_filter[lag - 1] = np.nanstd(close[:lag])

    for i in range(lag, close.shape[0]):
        _up_factor = abs(flex_select_1d_nb(up_factor_, i))
        _down_factor = abs(flex_select_1d_nb(down_factor_, i))
        _mean_influence = abs(flex_select_1d_nb(mean_influence_, i))
        _std_influence = abs(flex_select_1d_nb(std_influence_, i))

        up_crossed = close[i] - mean_filter[i - 1] >= _up_factor * std_filter[i - 1]
        down_crossed = close[i] - mean_filter[i - 1] <= -_down_factor * std_filter[i - 1]
        if up_crossed or down_crossed:
            # An upward crossing takes precedence if both conditions hold
            if up_crossed:
                signal[i] = 1
            else:
                signal[i] = -1
            # Dampen the effect of the signal value on the filters
            close_mean_filter[i] = _mean_influence * close[i] + (1 - _mean_influence) * close_mean_filter[i - 1]
            close_std_filter[i] = _std_influence * close[i] + (1 - _std_influence) * close_std_filter[i - 1]
        else:
            # signal[i] is already 0
            close_mean_filter[i] = close[i]
            close_std_filter[i] = close[i]

        mean_filter[i] = np.nanmean(close_mean_filter[(i - lag + 1) : i + 1])
        std_filter[i] = np.nanstd(close_std_filter[(i - lag + 1) : i + 1])
        # Bands combine the current mean with the previous std
        upper_band[i] = mean_filter[i] + _up_factor * std_filter[i - 1]
        lower_band[i] = mean_filter[i] - _down_factor * std_filter[i - 1]

    return signal, upper_band, lower_band
@register_jitted(cache=True)
def get_rs_nb(close: tp.Array1d) -> float:
    """Get rescaled range (R/S) for Hurst exponent estimation.

    Computes the simple returns of `close`, the range `R` of their cumulative deviations
    from the mean, and their standard deviation `S` (with `ddof=1`), and returns `R / S`.

    Returns 0.0 if `close` has fewer than two elements (no returns can be formed)
    or if either `R` or `S` is zero."""
    if len(close) < 2:
        # No increments: avoid dividing by zero and reducing over an empty array
        return 0.0
    incs = close[1:] / close[:-1] - 1.0
    mean_inc = np.sum(incs) / len(incs)
    deviations = incs - mean_inc
    Z = np.cumsum(deviations)
    R = np.max(Z) - np.min(Z)
    S = generic_nb.nanstd_1d_nb(incs, ddof=1)
    if R == 0 or S == 0:
        # Return a float in every branch to keep the return type uniform
        return 0.0
    return R / S
@register_jitted(cache=True)
def get_rs_hurst_nb(
    close: tp.Array1d,
    min_chunk: int = 8,
    max_chunk: int = 100,
    num_chunks: int = 5,
) -> float:
    """Estimate the Hurst exponent using R/S method.

    Windows are linearly distributed.

    The differences of `close` are split into non-overlapping chunks for each of `num_chunks`
    chunk sizes spaced linearly between `min_chunk` and `max_chunk + 1` (capped by the number
    of differences minus one). The mean rescaled range of each chunk size is regressed
    against the chunk size in log-log space, and the slope is returned.

    Returns NaN if no chunk size yields a usable rescaled range."""
    diff = close[1:] - close[:-1]
    N = len(diff)
    max_chunk += 1
    max_chunk = min(max_chunk, N - 1)
    rs_tmp = np.empty(N, dtype=float_)
    chunk_size_range = np.linspace(min_chunk, max_chunk, num_chunks).astype(int_)
    chunk_size_list = np.empty(len(chunk_size_range), dtype=int_)
    rs_values_list = np.empty(len(chunk_size_range), dtype=float_)

    k = 0
    for chunk_size in chunk_size_range:
        if chunk_size <= 0:
            # Can happen for very short inputs where the cap on max_chunk is not positive
            continue
        number_of_chunks = N // chunk_size
        if number_of_chunks == 0:
            # Chunk size exceeds the data: there is nothing to average, and reusing
            # values left in rs_tmp by a previous chunk size would be wrong
            continue
        for idx in range(number_of_chunks):
            ini = idx * chunk_size
            end = ini + chunk_size
            chunk = diff[ini:end]
            z = np.cumsum(chunk - np.mean(chunk))
            # np.divide yields NaN/inf for a zero std rather than raising
            rs_tmp[idx] = np.divide(np.max(z) - np.min(z), np.nanstd(chunk))
        rs = np.nanmean(rs_tmp[:number_of_chunks])
        if not np.isnan(rs) and rs != 0:
            chunk_size_list[k] = chunk_size
            rs_values_list[k] = rs
            k += 1

    if k == 0:
        # Consistent with the other estimators: no points to fit
        return np.nan
    H, c = np.linalg.lstsq(
        a=np.vstack((np.log(chunk_size_list[:k]), np.ones(len(chunk_size_list[:k])))).T,
        b=np.log(rs_values_list[:k]),
        rcond=-1,
    )[0]
    return H
@register_jitted(cache=True)
def get_dsod_hurst_nb(close: tp.Array1d) -> float:
    """Estimate the Hurst exponent using discrete second order derivative.

    Filters the cumulative sum of the differences of `close` with the kernels `[1, -2, 1]`
    and `[1, 0, -2, 0, 1]`, drops the first `len(kernel) - 1` outputs of each, and returns
    half the base-2 logarithm of the ratio of the mean squared outputs (wide over narrow)."""
    increments = close[1:] - close[:-1]
    path = np.cumsum(increments)

    narrow_kernel = [1, -2, 1]
    narrow_out = generic_nb.fir_filter_1d_nb(narrow_kernel, path)
    narrow_valid = narrow_out[len(narrow_kernel) - 1 :]

    wide_kernel = [1, 0, -2, 0, 1]
    wide_out = generic_nb.fir_filter_1d_nb(wide_kernel, path)
    wide_valid = wide_out[len(wide_kernel) - 1 :]

    narrow_power = np.mean(narrow_valid**2)
    wide_power = np.mean(wide_valid**2)
    return 0.5 * np.log2(wide_power / narrow_power)
@register_jitted(cache=True)
def rolling_hurst_1d_nb(
    close: tp.Array1d,
    window: int,
    method: int = HurstMethod.Standard,
    max_lag: int = 20,
    min_log: int = 1,
    max_log: int = 2,
    log_step: int = 0.25,
    min_chunk: int = 8,
    max_chunk: int = 100,
    num_chunks: int = 5,
    minp: tp.Optional[int] = None,
    stabilize: bool = False,
) -> tp.Array1d:
    """Rolling version of `get_hurst_nb`.

    For `method`, see `vectorbtpro.indicators.enums.HurstMethod`.

    For each row, `get_hurst_nb` is called on the trailing window of up to `window` values
    of `close`, provided that the window contains at least `minp` non-NaN values; otherwise
    the output is NaN. `minp` defaults to `window`.

    Raises a `ValueError` if `minp` is greater than `window`."""
    if minp is None:
        minp = window
    if minp > window:
        raise ValueError("minp must be <= window")

    out = np.empty_like(close, dtype=float_)
    nan_count = 0  # number of NaNs inside the current window
    for i in range(close.shape[0]):
        if np.isnan(close[i]):
            nan_count += 1
        if i >= window:
            # The window is full: account for the value that has just left it
            if np.isnan(close[i - window]):
                nan_count -= 1
            n_valid = window - nan_count
        else:
            n_valid = i + 1 - nan_count

        if n_valid < minp:
            out[i] = np.nan
            continue

        start = max(0, i + 1 - window)
        stop = i + 1
        out[i] = get_hurst_nb(
            close[start:stop],
            method=method,
            max_lag=max_lag,
            min_log=min_log,
            max_log=max_log,
            log_step=log_step,
            min_chunk=min_chunk,
            max_chunk=max_chunk,
            num_chunks=num_chunks,
            stabilize=stabilize,
        )
    return out
def talib_func(func_name: str) -> tp.Callable:
    """Get the TA-Lib indicator function.

    Looks up `func_name` (case-insensitive) in `talib` and returns a wrapper that
    accepts NumPy arrays, Series, and DataFrames, optionally resamples the inputs
    to another timeframe, optionally skips NaN rows, and wraps the outputs back
    into Pandas objects.

    Args:
        func_name: Name of the TA-Lib function, such as "SMA".

    Returns:
        A function named `run_<func_name>` whose signature lists the TA-Lib inputs
        and parameters followed by the keyword-only options of the wrapper.
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("talib")
    import talib
    from talib import abstract

    func_name = func_name.upper()
    talib_func = getattr(talib, func_name)
    info = abstract.Function(func_name).info
    # Flatten the input names: an entry can be a single name or a list of names.
    input_names = []
    for in_names in info["input_names"].values():
        if isinstance(in_names, (list, tuple)):
            input_names.extend(list(in_names))
        else:
            input_names.append(in_names)
    output_names = info["output_names"]
    one_output = len(output_names) == 1
    param_names = list(info["parameters"].keys())

    def run_talib_func(
        *args,
        timeframe: tp.Optional[tp.FrequencyLike] = None,
        resample_map: tp.KwargsLike = None,
        resample_kwargs: tp.KwargsLikeSequence = None,
        realign_kwargs: tp.KwargsLikeSequence = None,
        wrapper: tp.Optional[ArrayWrapper] = None,
        skipna: bool = False,
        silence_warnings: bool = False,
        broadcast_kwargs: tp.KwargsLike = None,
        wrap_kwargs: tp.KwargsLike = None,
        wrap: tp.Optional[bool] = None,
        unpack_to: tp.Optional[str] = None,
        **kwargs,
    ) -> tp.Union[tp.MaybeTuple[tp.AnyArray], tp.Dict[str, tp.AnyArray]]:
        # NOTE: `__doc__` of this function is assigned below, after its definition.
        if broadcast_kwargs is None:
            broadcast_kwargs = {}
        if wrap_kwargs is None:
            wrap_kwargs = {}

        # The leading positional arguments are inputs, the rest are TA-Lib parameters.
        inputs = []
        other_args = []
        for k, arg in enumerate(args):
            if k < len(input_names) and len(inputs) < len(input_names):
                inputs.append(arg)
            else:
                other_args.append(arg)
        # Inputs that were not passed positionally may be passed by name.
        if len(inputs) < len(input_names):
            for k in input_names:
                if k in kwargs:
                    inputs.append(kwargs.pop(k))

        # Decide whether the inputs must be broadcast against each other.
        is_pandas = False
        common_type = None
        common_shape = None
        broadcasting_needed = False
        new_inputs = []
        for arr in inputs:
            if isinstance(arr, (pd.Series, pd.DataFrame)):
                is_pandas = True
            elif not isinstance(arr, np.ndarray):
                arr = np.asarray(arr)
            if common_type is None:
                common_type = type(arr)
            elif type(arr) != common_type:
                broadcasting_needed = True
            if common_shape is None:
                common_shape = arr.shape
            elif arr.shape != common_shape:
                broadcasting_needed = True
            new_inputs.append(arr)
        inputs = new_inputs
        if broadcasting_needed:
            if is_pandas:
                if wrapper is None:
                    inputs, wrapper = broadcast(
                        dict(zip(input_names, inputs)),
                        return_wrapper=True,
                        **broadcast_kwargs,
                    )
                else:
                    inputs = broadcast(dict(zip(input_names, inputs)), **broadcast_kwargs)
                inputs = [inputs[k].values for k in input_names]
            else:
                inputs = broadcast_arrays(*inputs)
        else:
            if is_pandas:
                if wrapper is None:
                    wrapper = ArrayWrapper.from_obj(inputs[0])
                inputs = [arr.values for arr in inputs]

        def _run_talib_func(inputs, *_args, **_kwargs):
            """Run the TA-Lib function on one set of one-dimensional inputs.

            Always returns a sequence of output arrays, one per TA-Lib output.
            """
            target_index = None
            if timeframe is not None:
                if wrapper is None:
                    raise ValueError("Resampling requires a wrapper")
                if wrapper.freq is None:
                    if not silence_warnings:
                        warn(
                            "Couldn't parse the frequency of index. "
                            "Set freq in wrapper_kwargs via broadcast_kwargs, or globally."
                        )
                new_inputs = ()
                # NOTE(review): every other `merge_dicts` call in this module passes the
                # defaults first and the user dict second; here the user-provided
                # `resample_map` comes first - confirm which side takes precedence.
                _resample_map = merge_dicts(
                    resample_map,
                    {
                        "open": "first",
                        "high": "max",
                        "low": "min",
                        "close": "last",
                        "volume": "sum",
                    },
                )
                source_wrapper = ArrayWrapper(index=wrapper.index, freq=wrapper.freq)
                for i, arr in enumerate(inputs):
                    _resample_kwargs = resolve_dict(resample_kwargs, i=i)
                    new_input = GenericAccessor(source_wrapper, arr).resample_apply(
                        timeframe,
                        _resample_map[input_names[i]],
                        **_resample_kwargs,
                    )
                    target_index = new_input.index
                    new_inputs += (new_input.values,)
                inputs = new_inputs

            # Shape of the data this call actually operates on: a single column, after
            # any resampling and before NaN rows are squeezed out. The NaN fallback must
            # use it rather than the shape of the original (possibly two-dimensional,
            # not yet resampled) inputs.
            fallback_shape = inputs[0].shape

            def _build_nan_outputs():
                # Always a tuple, even for a single output, because the code below
                # (realigning, column stacking, wrapping, `outputs[0]`) iterates over
                # or indexes into the outputs.
                return tuple(np.full(fallback_shape, np.nan, dtype=np.double) for _ in output_names)

            all_nan = False
            if skipna:
                nan_mask = build_nan_mask(*inputs)
                if nan_mask.all():
                    all_nan = True
                else:
                    inputs = squeeze_nan(*inputs, nan_mask=nan_mask)
            else:
                nan_mask = None
            if all_nan:
                outputs = _build_nan_outputs()
            else:
                inputs = tuple([arr.astype(np.double) for arr in inputs])
                try:
                    outputs = talib_func(*inputs, *_args, **_kwargs)
                except Exception as e:
                    # Only the all-NaN complaint is converted into NaN outputs.
                    if "inputs are all NaN" in str(e):
                        outputs = _build_nan_outputs()
                        all_nan = True
                    else:
                        raise
                if not all_nan:
                    if one_output:
                        outputs = unsqueeze_nan(outputs, nan_mask=nan_mask)
                    else:
                        outputs = unsqueeze_nan(*outputs, nan_mask=nan_mask)
            if timeframe is not None:
                # Bring the outputs back to the original index.
                new_outputs = ()
                target_wrapper = ArrayWrapper(index=target_index)
                for i, output in enumerate(outputs):
                    _realign_kwargs = merge_dicts(
                        dict(
                            source_rbound=True,
                            target_rbound=True,
                            nan_value=np.nan,
                            ffill=True,
                            silence_warnings=True,
                        ),
                        resolve_dict(realign_kwargs, i=i),
                    )
                    new_output = GenericAccessor(target_wrapper, output).realign(
                        wrapper.index,
                        freq=wrapper.freq,
                        **_realign_kwargs,
                    )
                    new_outputs += (new_output.values,)
                outputs = new_outputs
            return outputs

        if inputs[0].ndim == 1:
            outputs = _run_talib_func(inputs, *other_args, **kwargs)
        else:
            # TA-Lib works on one-dimensional arrays: run column by column and stack.
            outputs = []
            for col in range(inputs[0].shape[1]):
                col_inputs = [arr[:, col] for arr in inputs]
                col_outputs = _run_talib_func(col_inputs, *other_args, **kwargs)
                outputs.append(col_outputs)
            outputs = list(zip(*outputs))
            outputs = tuple(map(column_stack_arrays, outputs))
        if wrap is None:
            wrap = is_pandas
        if wrap:
            outputs = [wrapper.wrap(output, **wrap_kwargs) for output in outputs]
        if unpack_to is not None:
            if unpack_to.lower() in ("dict", "frame"):
                dct = {name: outputs[i] for i, name in enumerate(output_names)}
                if unpack_to.lower() == "dict":
                    return dct
                return pd.concat(list(dct.values()), axis=1, keys=pd.Index(list(dct.keys()), name="output"))
            raise ValueError(f"Invalid unpack_to: '{unpack_to}'")
        if one_output:
            return outputs[0]
        return outputs

    # Replace `*args` in the signature by the explicit TA-Lib inputs and parameters.
    signature = inspect.signature(run_talib_func)
    new_parameters = list(signature.parameters.values())[1:]
    k = 0
    for input_name in input_names:
        new_parameters.insert(
            k,
            inspect.Parameter(
                input_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=tp.ArrayLike,
            ),
        )
        k += 1
    for param_name in param_names:
        new_parameters.insert(
            k,
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=info["parameters"][param_name],
                annotation=tp.Scalar,
            ),
        )
        k += 1
    run_talib_func.__signature__ = signature.replace(parameters=new_parameters)
    run_talib_func.__name__ = "run_" + func_name.lower()
    run_talib_func.__qualname__ = run_talib_func.__name__
    run_talib_func.__doc__ = f"""Run `talib.{func_name}` on NumPy arrays, Series, and DataFrames.

    Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed.

    Set `timeframe` to a frequency to resample the input arrays to this frequency, run the function,
    and then resample the output arrays back to the original frequency. Optionally, provide `resample_map`
    as a dictionary that maps input names to resample-apply function names. Keyword arguments
    `resample_kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.resample_apply`
    while `realign_kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.realign`.
    Both can be also provided as sequences of dictionaries - one dictionary per input and output respectively.

    Set `skipna` to True to run the TA-Lib function on non-NA values only.

    Broadcasts the input arrays if they have different types or shapes.
    If one of the input arrays is a Series/DataFrame, wraps the output arrays into a Pandas format.
    To enable or disable wrapping, set `wrap` to True and False respectively."""

    return run_talib_func


def talib_plot_func(func_name: str) -> tp.Callable:
    """Get the TA-Lib indicator plotting function.

    Args:
        func_name: Name of the TA-Lib function, such as "BBANDS" (case-insensitive).

    Returns:
        A function named `plot_<func_name>` that takes the output arrays of the
        TA-Lib function and draws each of them according to its TA-Lib output flags.
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("talib")
    from talib import abstract
    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]

    func_name = func_name.upper()
    info = abstract.Function(func_name).info
    output_names = info["output_names"]
    output_flags = info["output_flags"]

    def run_talib_plot_func(
        *outputs,
        wrapper: tp.Optional[ArrayWrapper] = None,
        wrap_kwargs: tp.KwargsLike = None,
        column: tp.Optional[tp.Label] = None,
        limits: tp.Optional[tp.Tuple[float, float]] = None,
        add_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **kwargs,
    ) -> tp.BaseFigure:
        # NOTE: `__doc__` of this function is assigned below, after its definition.
        if wrap_kwargs is None:
            wrap_kwargs = {}

        # Convert every output to Pandas and reduce it to the requested column.
        new_outputs = []
        for output in outputs:
            if not isinstance(output, (pd.Series, pd.DataFrame)):
                if wrapper is not None:
                    output = wrapper.wrap(output, **wrap_kwargs)
                else:
                    output = to_pd_array(output)
            if wrapper is not None:
                output = Wrapping.select_col_from_obj(
                    output,
                    column=column,
                    wrapper=wrapper,
                )
            else:
                output = Wrapping.select_col_from_obj(
                    output,
                    column=column,
                    wrapper=ArrayWrapper.from_obj(output),
                )
            new_outputs.append(output)
        outputs = dict(zip(output_names, new_outputs))

        # Per-output trace overrides arrive as `<output_name>_trace_kwargs`.
        output_trace_kwargs = {}
        for output_name in output_names:
            output_trace_kwargs[output_name] = kwargs.pop(output_name + "_trace_kwargs", {})

        # Limit outputs are drawn first, lower limits before upper limits, so that the
        # "tonexty" fill of an upper limit refers to the trace added just before it.
        priority_outputs = []
        other_outputs = []
        for output_name in output_names:
            flags = set(output_flags.get(output_name))
            found_priority = False
            if abstract.TA_OUTPUT_FLAGS[2048] in flags:
                priority_outputs = priority_outputs + [output_name]
                found_priority = True
            if abstract.TA_OUTPUT_FLAGS[4096] in flags:
                priority_outputs = [output_name] + priority_outputs
                found_priority = True
            if not found_priority:
                other_outputs.append(output_name)

        for output_name in priority_outputs + other_outputs:
            output = outputs[output_name].rename(output_name)
            flags = set(output_flags.get(output_name))
            trace_kwargs = {}
            plot_func_name = "lineplot"

            if abstract.TA_OUTPUT_FLAGS[2] in flags:
                # Dotted Line
                if "line" not in trace_kwargs:
                    trace_kwargs["line"] = dict()
                trace_kwargs["line"]["dash"] = "dashdot"
            if abstract.TA_OUTPUT_FLAGS[4] in flags:
                # Dashed Line
                if "line" not in trace_kwargs:
                    trace_kwargs["line"] = dict()
                trace_kwargs["line"]["dash"] = "dash"
            if abstract.TA_OUTPUT_FLAGS[8] in flags:
                # Dot
                if "line" not in trace_kwargs:
                    trace_kwargs["line"] = dict()
                trace_kwargs["line"]["dash"] = "dot"
            if abstract.TA_OUTPUT_FLAGS[16] in flags:
                # Histogram: bars are colored by sign and by whether they grow or shrink
                hist = np.asarray(output)
                hist_diff = generic_nb.diff_1d_nb(hist)
                marker_colors = np.full(hist.shape, adjust_opacity("silver", 0.75), dtype=object)
                marker_colors[(hist > 0) & (hist_diff > 0)] = adjust_opacity("green", 0.75)
                marker_colors[(hist > 0) & (hist_diff <= 0)] = adjust_opacity("lightgreen", 0.75)
                marker_colors[(hist < 0) & (hist_diff < 0)] = adjust_opacity("red", 0.75)
                marker_colors[(hist < 0) & (hist_diff >= 0)] = adjust_opacity("lightcoral", 0.75)
                if "marker" not in trace_kwargs:
                    trace_kwargs["marker"] = {}
                trace_kwargs["marker"]["color"] = marker_colors
                if "line" not in trace_kwargs["marker"]:
                    trace_kwargs["marker"]["line"] = {}
                trace_kwargs["marker"]["line"]["width"] = 0
                kwargs["bargap"] = 0
                plot_func_name = "barplot"
            if abstract.TA_OUTPUT_FLAGS[2048] in flags:
                # Values represent an upper limit
                if "line" not in trace_kwargs:
                    trace_kwargs["line"] = {}
                trace_kwargs["line"]["color"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75)
                trace_kwargs["fill"] = "tonexty"
                trace_kwargs["fillcolor"] = "rgba(128, 128, 128, 0.2)"
            if abstract.TA_OUTPUT_FLAGS[4096] in flags:
                # Values represent a lower limit
                if "line" not in trace_kwargs:
                    trace_kwargs["line"] = {}
                trace_kwargs["line"]["color"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75)

            # User-provided trace settings are merged over the flag-derived ones.
            trace_kwargs = merge_dicts(trace_kwargs, output_trace_kwargs[output_name])
            plot_func = getattr(output.vbt, plot_func_name)
            fig = plot_func(trace_kwargs=trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, **kwargs)

        if limits is not None:
            # Draw the band between both limits on the axes of the last added trace.
            xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
            yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
            add_shape_kwargs = merge_dicts(
                dict(
                    type="rect",
                    xref=xref,
                    yref=yref,
                    x0=outputs[output_names[0]].index[0],
                    y0=limits[0],
                    x1=outputs[output_names[0]].index[-1],
                    y1=limits[1],
                    fillcolor="mediumslateblue",
                    opacity=0.2,
                    layer="below",
                    line_width=0,
                ),
                add_shape_kwargs,
            )
            fig.add_shape(**add_shape_kwargs)

        return fig

    # Replace `*outputs` by the explicit output names, add one `<output>_trace_kwargs`
    # option per output, and rename the trailing `**kwargs` to `**layout_kwargs`.
    signature = inspect.signature(run_talib_plot_func)
    new_parameters = list(signature.parameters.values())[1:-1]
    k = 0
    for output_name in output_names:
        new_parameters.insert(
            k,
            inspect.Parameter(
                output_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                annotation=tp.ArrayLike,
            ),
        )
        k += 1
    for output_name in output_names:
        new_parameters.insert(
            -3,
            inspect.Parameter(
                output_name + "_trace_kwargs",
                inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=tp.KwargsLike,
            ),
        )
    new_parameters.append(inspect.Parameter("layout_kwargs", inspect.Parameter.VAR_KEYWORD))
    run_talib_plot_func.__signature__ = signature.replace(parameters=new_parameters)
    output_trace_kwargs_docstring = "\n        ".join(
        [
            f"{output_name}_trace_kwargs (dict): Keyword arguments passed to the trace of `{output_name}`."
            for output_name in output_names
        ]
    )
    run_talib_plot_func.__name__ = "plot_" + func_name.lower()
    run_talib_plot_func.__qualname__ = run_talib_plot_func.__name__
    run_talib_plot_func.__doc__ = f"""Plot output arrays of `talib.{func_name}`.

    Args:
        column (str): Name of the column to plot.
        limits (tuple of float): Tuple of the lower and upper limit.
        {output_trace_kwargs_docstring}
        add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding the range between both limits.
        add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
        fig (Figure or FigureWidget): Figure to add the traces to.
        **layout_kwargs: Keyword arguments passed to `fig.update_layout`."""

    return run_talib_plot_func
class _BOLB(BOLB):
    """Label generator built on top of `vectorbtpro.labels.nb.breakout_labels_nb`."""

    def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
        """Plot the midpoint between `BOLB.high` and `BOLB.low` with `BOLB.labels` as a heatmap overlay.

        `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.

        Usage:
            ```pycon
            >>> vbt.BOLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
            ```

            ![](/assets/images/api/BOLB.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/BOLB.dark.svg#only-dark){: .iimg loading=lazy }
        """
        selected = self.select_col(column=column, group_by=False)
        # The price line is the midpoint of the high/low range.
        price_line = ((selected.high + selected.low) / 2).rename("Median")
        heat_labels = selected.labels.rename("Labels")
        return price_line.vbt.overlay_with_heatmap(heat_labels, **kwargs)
class _FMAX(FMAX):
    """Look-ahead indicator built on top of `vectorbtpro.labels.nb.future_max_nb`."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        fmax_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Draw `FMAX.fmax` as a line, optionally on top of `FMAX.close`.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `FMAX.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMAX.close`.
            fmax_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMAX.fmax`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.FMAX.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/FMAX.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/FMAX.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_settings = settings["plotting"]
        selected = self.select_col(column=column)

        # Reuse the given figure or start a new one; layout options apply either way.
        figure = make_figure() if fig is None else fig
        figure.update_layout(**layout_kwargs)

        # Defaults first, user overrides second.
        close_style = merge_dicts(
            dict(name="Close", line=dict(color=plot_settings["color_schema"]["blue"])),
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )
        fmax_style = merge_dicts(
            dict(name="Future max", line=dict(color=plot_settings["color_schema"]["lightblue"])),
            {} if fmax_trace_kwargs is None else fmax_trace_kwargs,
        )

        if plot_close:
            figure = selected.close.vbt.lineplot(
                trace_kwargs=close_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=figure,
            )
        return selected.fmax.vbt.lineplot(
            trace_kwargs=fmax_style,
            add_trace_kwargs=add_trace_kwargs,
            fig=figure,
        )
class _FMEAN(FMEAN):
    """Look-ahead indicator built on top of `vectorbtpro.labels.nb.future_mean_nb`."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        fmean_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Draw `FMEAN.fmean` as a line, optionally on top of `FMEAN.close`.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `FMEAN.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMEAN.close`.
            fmean_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMEAN.fmean`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.FMEAN.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/FMEAN.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/FMEAN.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_settings = settings["plotting"]
        selected = self.select_col(column=column)

        # Reuse the given figure or start a new one; layout options apply either way.
        figure = make_figure() if fig is None else fig
        figure.update_layout(**layout_kwargs)

        # Defaults first, user overrides second.
        close_style = merge_dicts(
            dict(name="Close", line=dict(color=plot_settings["color_schema"]["blue"])),
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )
        fmean_style = merge_dicts(
            dict(name="Future mean", line=dict(color=plot_settings["color_schema"]["lightblue"])),
            {} if fmean_trace_kwargs is None else fmean_trace_kwargs,
        )

        if plot_close:
            figure = selected.close.vbt.lineplot(
                trace_kwargs=close_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=figure,
            )
        return selected.fmean.vbt.lineplot(
            trace_kwargs=fmean_style,
            add_trace_kwargs=add_trace_kwargs,
            fig=figure,
        )
class _FMIN(FMIN):
    """Look-ahead indicator built on top of `vectorbtpro.labels.nb.future_min_nb`."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_close: bool = True,
        close_trace_kwargs: tp.KwargsLike = None,
        fmin_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Draw `FMIN.fmin` as a line, optionally on top of `FMIN.close`.

        Args:
            column (str): Name of the column to plot.
            plot_close (bool): Whether to plot `FMIN.close`.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMIN.close`.
            fmin_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMIN.fmin`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.FMIN.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/FMIN.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/FMIN.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_settings = settings["plotting"]
        selected = self.select_col(column=column)

        # Reuse the given figure or start a new one; layout options apply either way.
        figure = make_figure() if fig is None else fig
        figure.update_layout(**layout_kwargs)

        # Defaults first, user overrides second.
        close_style = merge_dicts(
            dict(name="Close", line=dict(color=plot_settings["color_schema"]["blue"])),
            {} if close_trace_kwargs is None else close_trace_kwargs,
        )
        fmin_style = merge_dicts(
            dict(name="Future min", line=dict(color=plot_settings["color_schema"]["lightblue"])),
            {} if fmin_trace_kwargs is None else fmin_trace_kwargs,
        )

        if plot_close:
            figure = selected.close.vbt.lineplot(
                trace_kwargs=close_style,
                add_trace_kwargs=add_trace_kwargs,
                fig=figure,
            )
        return selected.fmin.vbt.lineplot(
            trace_kwargs=fmin_style,
            add_trace_kwargs=add_trace_kwargs,
            fig=figure,
        )
class _FSTD(FSTD):
    """Look-ahead indicator built on top of `vectorbtpro.labels.nb.future_std_nb`."""

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        fstd_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Draw `FSTD.fstd` as a line.

        Args:
            column (str): Name of the column to plot.
            fstd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FSTD.fstd`.
            add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments passed to `fig.update_layout`.

        Usage:
            ```pycon
            >>> vbt.FSTD.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/FSTD.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/FSTD.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plot_settings = settings["plotting"]
        selected = self.select_col(column=column)

        # Reuse the given figure or start a new one; layout options apply either way.
        figure = make_figure() if fig is None else fig
        figure.update_layout(**layout_kwargs)

        # Defaults first, user overrides second.
        fstd_style = merge_dicts(
            dict(name="Future STD", line=dict(color=plot_settings["color_schema"]["lightblue"])),
            {} if fstd_trace_kwargs is None else fstd_trace_kwargs,
        )
        return selected.fstd.vbt.lineplot(
            trace_kwargs=fstd_style,
            add_trace_kwargs=add_trace_kwargs,
            fig=figure,
        )
class _MEANLB(MEANLB):
    """Label generator built on top of `vectorbtpro.labels.nb.mean_labels_nb`."""

    def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
        """Plot `MEANLB.close` with `MEANLB.labels` as a heatmap overlay.

        `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.

        Usage:
            ```pycon
            >>> vbt.MEANLB.run(ohlcv['Close']).plot().show()
            ```

            ![](/assets/images/api/MEANLB.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/MEANLB.dark.svg#only-dark){: .iimg loading=lazy }
        """
        selected = self.select_col(column=column, group_by=False)
        price_line = selected.close.rename("Close")
        heat_labels = selected.labels.rename("Labels")
        return price_line.vbt.overlay_with_heatmap(heat_labels, **kwargs)
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `PIVOTLB`."""

from vectorbtpro import _typing as tp
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb

__all__ = [
    "PIVOTLB",
]

__pdoc__ = {}

# Both thresholds share the same flexible, element-wise parameter config
PIVOTLB = IndicatorFactory(
    class_name="PIVOTLB",
    module_name=__name__,
    input_names=["high", "low"],
    param_names=["up_th", "down_th"],
    output_names=["labels"],
).with_apply_func(
    nb.pivots_nb,
    param_settings={
        "up_th": flex_elem_param_config,
        "down_th": flex_elem_param_config,
    },
)


class _PIVOTLB(PIVOTLB):
    """Label generator based on `vectorbtpro.labels.nb.pivots_nb`."""

    def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
        """Plot the median of `PIVOTLB.high` and `PIVOTLB.low`, and overlay it with the heatmap of `PIVOTLB.labels`.

        `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.

        Args:
            column: Label of the column to plot; forwarded to `select_col`.
            **kwargs: Keyword arguments forwarded to `overlay_with_heatmap`.

        Returns:
            The figure returned by `overlay_with_heatmap`.

        Usage:
            ```pycon
            >>> vbt.PIVOTLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
            ```

            ![](/assets/images/api/PIVOTLB.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/PIVOTLB.dark.svg#only-dark){: .iimg loading=lazy }
        """
        selected = self.select_col(column=column, group_by=False)
        # Midpoint between high and low serves as the price line
        mid_price = ((selected.high + selected.low) / 2).rename("Median")
        labels_sr = selected.labels.rename("Labels")
        return mid_price.vbt.overlay_with_heatmap(labels_sr, **kwargs)


# Copy the documented subclass members back onto the generated class
setattr(PIVOTLB, "__doc__", _PIVOTLB.__doc__)
setattr(PIVOTLB, "plot", _PIVOTLB.plot)
PIVOTLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `TRENDLB`."""

from vectorbtpro import _typing as tp
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.labels.enums import TrendLabelMode

__all__ = [
    "TRENDLB",
]

__pdoc__ = {}

# Thresholds use the flexible element-wise config; `mode` is declared with the
# `TrendLabelMode` enum as dtype, and its index labels are lower-cased.
TRENDLB = IndicatorFactory(
    class_name="TRENDLB",
    module_name=__name__,
    input_names=["high", "low"],
    param_names=["up_th", "down_th", "mode"],
    output_names=["labels"],
).with_apply_func(
    nb.trend_labels_nb,
    param_settings={
        "up_th": flex_elem_param_config,
        "down_th": flex_elem_param_config,
        "mode": {
            "dtype": TrendLabelMode,
            "post_index_func": lambda index: index.str.lower(),
        },
    },
    mode=TrendLabelMode.Binary,
)


class _TRENDLB(TRENDLB):
    """Label generator based on `vectorbtpro.labels.nb.trend_labels_nb`."""

    def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
        """Plot the median of `TRENDLB.high` and `TRENDLB.low`, and overlay it with the heatmap of `TRENDLB.labels`.

        `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.

        Args:
            column: Label of the column to plot; forwarded to `select_col`.
            **kwargs: Keyword arguments forwarded to `overlay_with_heatmap`.

        Returns:
            The figure returned by `overlay_with_heatmap`.

        Usage:
            ```pycon
            >>> vbt.TRENDLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
            ```

            ![](/assets/images/api/TRENDLB.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/TRENDLB.dark.svg#only-dark){: .iimg loading=lazy }
        """
        selected = self.select_col(column=column, group_by=False)
        # Midpoint between high and low serves as the price line
        mid_price = ((selected.high + selected.low) / 2).rename("Median")
        labels_sr = selected.labels.rename("Labels")
        return mid_price.vbt.overlay_with_heatmap(labels_sr, **kwargs)


# Copy the documented subclass members back onto the generated class
setattr(TRENDLB, "__doc__", _TRENDLB.__doc__)
setattr(TRENDLB, "plot", _TRENDLB.plot)
TRENDLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Modules for building and running look-ahead indicators and label generators."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vectorbtpro.labels.generators import *
    from vectorbtpro.labels.nb import *

__exclude_from__all__ = [
    "enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Named tuples and enumerated types for label generation.

Defines enums and other schemas for `vectorbtpro.labels`."""

from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify

__pdoc__all__ = __all__ = ["TrendLabelMode"]

__pdoc__ = {}


class TrendLabelModeT(tp.NamedTuple):
    Binary: int = 0
    BinaryCont: int = 1
    BinaryContSat: int = 2
    PctChange: int = 3
    PctChangeNorm: int = 4


TrendLabelMode = TrendLabelModeT()
"""_"""

__pdoc__[
    "TrendLabelMode"
] = f"""Trend label mode.

```python
{prettify(TrendLabelMode)}
```

Attributes:
    Binary: See `vectorbtpro.labels.nb.bin_trend_labels_nb`.
    BinaryCont: See `vectorbtpro.labels.nb.binc_trend_labels_nb`.
    BinaryContSat: See `vectorbtpro.labels.nb.bincs_trend_labels_nb`.
    PctChange: See `vectorbtpro.labels.nb.pct_trend_labels_nb` with `normalize` set to False.
    PctChangeNorm: See `vectorbtpro.labels.nb.pct_trend_labels_nb` with `normalize` set to True.
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for label generation.

!!! note
    Set `wait` to 1 to exclude the current value from calculation of future values.

!!! warning
    Do not attempt to use these functions for building predictor variables as they may introduce
    the look-ahead bias to your model - only use for building target variables."""

import numpy as np
from numba import prange

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators.enums import Pivot
from vectorbtpro.labels.enums import TrendLabelMode
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch

__all__ = []


# ############# FMEAN ############# #


@register_jitted(cache=True)
def future_mean_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Simple,
    wait: int = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Rolling average over future values.

    The moving average is computed on the reversed array and reversed back, so each
    window covers the current and following values. If `wait` is positive, the result
    is additionally passed through `vectorbtpro.generic.nb.bshift_1d_nb`.

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    # Reverse -> roll -> reverse turns a backward-looking window into a forward-looking one
    future_mean = generic_nb.ma_1d_nb(close[::-1], window, wtype=wtype, minp=minp, adjust=adjust)[::-1]
    if wait > 0:
        return generic_nb.bshift_1d_nb(future_mean, wait)
    return future_mean


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        wait=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_mean_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    wait: tp.FlexArray1dLike = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array2d:
    """2-dim version of `future_mean_1d_nb`.

    `window`, `wtype`, and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    wtype_ = to_1d_array_nb(np.asarray(wtype))
    wait_ = to_1d_array_nb(np.asarray(wait))

    future_mean = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        future_mean[:, col] = future_mean_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(window_, col),
            wtype=flex_select_1d_nb(wtype_, col),
            wait=flex_select_1d_nb(wait_, col),
            minp=minp,
            adjust=adjust,
        )
    return future_mean


# ############# FSTD ############# #


@register_jitted(cache=True)
def future_std_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Simple,
    wait: int = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
    ddof: int = 0,
) -> tp.Array1d:
    """Rolling standard deviation over future values.

    Computed on the reversed array and reversed back, so each window covers the
    current and following values. If `wait` is positive, the result is additionally
    passed through `vectorbtpro.generic.nb.bshift_1d_nb`.

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    # Reverse -> roll -> reverse turns a backward-looking window into a forward-looking one
    future_std = generic_nb.msd_1d_nb(close[::-1], window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)[::-1]
    if wait > 0:
        return generic_nb.bshift_1d_nb(future_std, wait)
    return future_std


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        wait=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
        ddof=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_std_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    wait: tp.FlexArray1dLike = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
    ddof: int = 0,
) -> tp.Array2d:
    """2-dim version of `future_std_1d_nb`.

    `window`, `wtype`, and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    wtype_ = to_1d_array_nb(np.asarray(wtype))
    wait_ = to_1d_array_nb(np.asarray(wait))

    future_std = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        future_std[:, col] = future_std_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(window_, col),
            wtype=flex_select_1d_nb(wtype_, col),
            wait=flex_select_1d_nb(wait_, col),
            minp=minp,
            adjust=adjust,
            ddof=ddof,
        )
    return future_std


# ############# FMIN ############# #


@register_jitted(cache=True)
def future_min_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wait: int = 1,
    minp: tp.Optional[int] = None,
) -> tp.Array1d:
    """Rolling minimum over future values.

    Computed on the reversed array and reversed back. If `wait` is positive, the
    result is additionally passed through `vectorbtpro.generic.nb.bshift_1d_nb`."""
    # Reverse -> roll -> reverse turns a backward-looking window into a forward-looking one
    future_min = generic_nb.rolling_min_1d_nb(close[::-1], window, minp=minp)[::-1]
    if wait > 0:
        return generic_nb.bshift_1d_nb(future_min, wait)
    return future_min


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wait=base_ch.FlexArraySlicer(),
        minp=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_min_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wait: tp.FlexArray1dLike = 1,
    minp: tp.Optional[int] = None,
) -> tp.Array2d:
    """2-dim version of `future_min_1d_nb`.

    `window` and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    wait_ = to_1d_array_nb(np.asarray(wait))

    future_min = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        future_min[:, col] = future_min_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(window_, col),
            wait=flex_select_1d_nb(wait_, col),
            minp=minp,
        )
    return future_min


# ############# FMAX ############# #


@register_jitted(cache=True)
def future_max_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wait: int = 1,
    minp: tp.Optional[int] = None,
) -> tp.Array1d:
    """Rolling maximum over future values.

    Computed on the reversed array and reversed back. If `wait` is positive, the
    result is additionally passed through `vectorbtpro.generic.nb.bshift_1d_nb`."""
    # Reverse -> roll -> reverse turns a backward-looking window into a forward-looking one
    future_max = generic_nb.rolling_max_1d_nb(close[::-1], window, minp=minp)[::-1]
    if wait > 0:
        return generic_nb.bshift_1d_nb(future_max, wait)
    return future_max


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wait=base_ch.FlexArraySlicer(),
        minp=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_max_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wait: tp.FlexArray1dLike = 1,
    minp: tp.Optional[int] = None,
) -> tp.Array2d:
    """2-dim version of `future_max_1d_nb`.

    `window` and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    wait_ = to_1d_array_nb(np.asarray(wait))

    future_max = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        future_max[:, col] = future_max_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(window_, col),
            wait=flex_select_1d_nb(wait_, col),
            minp=minp,
        )
    return future_max


# ############# FIXLB ############# #


@register_jitted(cache=True)
def fixed_labels_1d_nb(
    close: tp.Array1d,
    n: int = 1,
) -> tp.Array1d:
    """Percentage change of the current value relative to a future value.

    Computed as `(bshift(close, n) - close) / close`."""
    return (generic_nb.bshift_1d_nb(close, n) - close) / close


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        n=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def fixed_labels_nb(
    close: tp.Array2d,
    n: tp.FlexArray1dLike = 1,
) -> tp.Array2d:
    """2-dim version of `fixed_labels_1d_nb`.

    `n` is selected flexibly per column."""
    n_ = to_1d_array_nb(np.asarray(n))

    fixed_labels = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        fixed_labels[:, col] = fixed_labels_1d_nb(
            close=close[:, col],
            n=flex_select_1d_nb(n_, col),
        )
    return fixed_labels


# ############# MEANLB ############# #


@register_jitted(cache=True)
def mean_labels_1d_nb(
    close: tp.Array1d,
    window: int = 14,
    wtype: int = generic_enums.WType.Simple,
    wait: int = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array1d:
    """Percentage change of the current value relative to the average of a future period.

    Computed as `(future_mean - close) / close`, where `future_mean` comes from
    `future_mean_1d_nb` called with the same arguments.

    For `wtype`, see `vectorbtpro.generic.enums.WType`."""
    future_mean = future_mean_1d_nb(close, window=window, wtype=wtype, wait=wait, minp=minp, adjust=adjust)
    return (future_mean - close) / close


@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        wtype=base_ch.FlexArraySlicer(),
        wait=base_ch.FlexArraySlicer(),
        minp=None,
        adjust=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mean_labels_nb(
    close: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
    wait: tp.FlexArray1dLike = 1,
    minp: tp.Optional[int] = None,
    adjust: bool = False,
) -> tp.Array2d:
    """2-dim version of `mean_labels_1d_nb`.

    `window`, `wtype`, and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    wtype_ = to_1d_array_nb(np.asarray(wtype))
    wait_ = to_1d_array_nb(np.asarray(wait))

    mean_labels = np.empty(close.shape, dtype=float_)
    for col in prange(close.shape[1]):
        mean_labels[:, col] = mean_labels_1d_nb(
            close=close[:, col],
            window=flex_select_1d_nb(window_, col),
            wtype=flex_select_1d_nb(wtype_, col),
            wait=flex_select_1d_nb(wait_, col),
            minp=minp,
            adjust=adjust,
        )
    return mean_labels


# ############# PIVOTLB ############# #


@register_jitted(cache=True)
def iter_symmetric_up_th_nb(down_th: float) -> float:
    """Positive upper threshold that is symmetric to a negative one at one iteration.

    For example, 50% down requires 100% to go up to the initial level."""
    return down_th / (1 - down_th)


@register_jitted(cache=True)
def iter_symmetric_down_th_nb(up_th: float) -> float:
    """Negative lower threshold that is symmetric to a positive one at one iteration.

    For example, 100% up requires 50% to go down to the initial level."""
    return up_th / (1 + up_th)


@register_jitted(cache=True)
def pivots_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    up_th: tp.FlexArray1dLike,
    down_th: tp.FlexArray1dLike,
) -> tp.Array1d:
    """Pivots denoted by 1 (peak), 0 (no pivot) or -1 (valley).

    Two adjacent peak and valley points should exceed the given threshold parameters.

    If any threshold is given element-wise, it will be applied per new/updated pivot.

    Rows where `high` or `low` is NaN are skipped. The first row where both are valid
    serves as the reference point for detecting the very first pivot."""
    up_th_ = to_1d_array_nb(np.asarray(up_th))
    down_th_ = to_1d_array_nb(np.asarray(down_th))
    pivots = np.full(high.shape, 0, dtype=int_)

    last_pivot = 0
    last_i = -1
    last_value = np.nan
    first_valid_i = -1
    for i in range(high.shape[0]):
        if not np.isnan(high[i]) and not np.isnan(low[i]):
            if first_valid_i == -1:
                # Remember the first row with valid data; using a hard-coded 0 here would
                # compare against NaN whenever the series starts with missing values
                first_valid_i = i
            if last_i == -1:
                # No pivot candidate yet: compare against the first valid row
                _up_th = 1 + abs(flex_select_1d_nb(up_th_, first_valid_i))
                _down_th = 1 - abs(flex_select_1d_nb(down_th_, first_valid_i))
                if not np.isnan(_up_th) and high[i] >= low[first_valid_i] * _up_th:
                    if not np.isnan(_down_th) and low[i] <= high[first_valid_i] * _down_th:
                        pass  # wait
                    else:
                        pivots[first_valid_i] = Pivot.Valley
                        last_i = i
                        last_value = high[i]
                        last_pivot = Pivot.Peak
                if not np.isnan(_down_th) and low[i] <= high[first_valid_i] * _down_th:
                    if not np.isnan(_up_th) and high[i] >= low[first_valid_i] * _up_th:
                        pass  # wait
                    else:
                        pivots[first_valid_i] = Pivot.Peak
                        last_i = i
                        last_value = low[i]
                        last_pivot = Pivot.Valley
            else:
                # Thresholds are taken at the position of the current pivot candidate
                _up_th = 1 + abs(flex_select_1d_nb(up_th_, last_i))
                _down_th = 1 - abs(flex_select_1d_nb(down_th_, last_i))
                if last_pivot == Pivot.Valley:
                    if not np.isnan(last_value) and not np.isnan(_up_th) and high[i] >= last_value * _up_th:
                        # Reversal confirmed: fix the valley and start tracking a peak
                        pivots[last_i] = last_pivot
                        last_i = i
                        last_value = high[i]
                        last_pivot = Pivot.Peak
                    elif np.isnan(last_value) or low[i] < last_value:
                        # Lower low: move the valley candidate
                        last_i = i
                        last_value = low[i]
                elif last_pivot == Pivot.Peak:
                    if not np.isnan(last_value) and not np.isnan(_down_th) and low[i] <= last_value * _down_th:
                        # Reversal confirmed: fix the peak and start tracking a valley
                        pivots[last_i] = last_pivot
                        last_i = i
                        last_value = low[i]
                        last_pivot = Pivot.Valley
                    elif np.isnan(last_value) or high[i] > last_value:
                        # Higher high: move the peak candidate
                        last_i = i
                        last_value = high[i]

        # At the last row, commit the pending candidate as the final pivot
        if last_i != -1 and i == high.shape[0] - 1:
            pivots[last_i] = last_pivot

    return pivots


@register_chunkable(
    size=ch.ArraySizer(arg_query="high", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        up_th=base_ch.FlexArraySlicer(axis=1),
        down_th=base_ch.FlexArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pivots_nb(
    high: tp.Array2d,
    low: tp.Array2d,
    up_th: tp.FlexArray2dLike,
    down_th: tp.FlexArray2dLike,
) -> tp.Array2d:
    """2-dim version of `pivots_1d_nb`.

    `up_th` and `down_th` are selected flexibly per column."""
    up_th_ = to_2d_array_nb(np.asarray(up_th))
    down_th_ = to_2d_array_nb(np.asarray(down_th))

    pivots = np.empty(high.shape, dtype=int_)
    for col in prange(high.shape[1]):
        pivots[:, col] = pivots_1d_nb(
            high[:, col],
            low[:, col],
            flex_select_col_nb(up_th_, col),
            flex_select_col_nb(down_th_, col),
        )
    return pivots


# ############# TRENDLB ############# #


@register_jitted(cache=True)
def bin_trend_labels_1d_nb(pivots: tp.Array1d) -> tp.Array1d:
    """Values classified into 0 (downtrend) and 1 (uptrend).

    Each row between two consecutive pivots (the second one excluded) gets 1 if the
    next pivot is a peak and 0 otherwise. Rows outside any pivot pair remain NaN."""
    bin_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
    idxs = np.flatnonzero(pivots)
    if idxs.shape[0] == 0:
        return bin_trend_labels

    for k in range(1, idxs.shape[0]):
        prev_i = idxs[k - 1]
        next_i = idxs[k]

        for i in range(prev_i, next_i):
            if pivots[next_i] == Pivot.Peak:
                bin_trend_labels[i] = 1
            else:
                bin_trend_labels[i] = 0

    return bin_trend_labels


@register_chunkable(
    size=ch.ArraySizer(arg_query="pivots", axis=1),
    arg_take_spec=dict(
        pivots=ch.ArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bin_trend_labels_nb(pivots: tp.Array2d) -> tp.Array2d:
    """2-dim version of `bin_trend_labels_1d_nb`."""
    bin_trend_labels = np.empty(pivots.shape, dtype=float_)
    for col in prange(pivots.shape[1]):
        bin_trend_labels[:, col] = bin_trend_labels_1d_nb(pivots[:, col])
    return bin_trend_labels


@register_jitted(cache=True)
def binc_trend_labels_1d_nb(high: tp.Array1d, low: tp.Array1d, pivots: tp.Array1d) -> tp.Array1d:
    """Median values normalized between 0 (downtrend) and 1 (uptrend).

    For each pair of consecutive pivots, the median of `high` and `low` is scaled by the
    NaN-ignoring minimum of `low` and maximum of `high` over that segment, then inverted."""
    binc_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
    idxs = np.flatnonzero(pivots[:])
    if idxs.shape[0] == 0:
        return binc_trend_labels

    for k in range(1, idxs.shape[0]):
        prev_i = idxs[k - 1]
        next_i = idxs[k]
        _min = np.nanmin(low[prev_i : next_i + 1])
        _max = np.nanmax(high[prev_i : next_i + 1])

        for i in range(prev_i, next_i):
            _med = (high[i] + low[i]) / 2
            binc_trend_labels[i] = 1 - (_med - _min) / (_max - _min)

    return binc_trend_labels


@register_chunkable(
    size=ch.ArraySizer(arg_query="pivots", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        pivots=ch.ArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def binc_trend_labels_nb(high: tp.Array2d, low: tp.Array2d, pivots: tp.Array2d) -> tp.Array2d:
    """2-dim version of `binc_trend_labels_1d_nb`."""
    binc_trend_labels = np.empty(pivots.shape, dtype=float_)
    for col in prange(pivots.shape[1]):
        binc_trend_labels[:, col] = binc_trend_labels_1d_nb(high[:, col], low[:, col], pivots[:, col])
    return binc_trend_labels


@register_jitted(cache=True)
def bincs_trend_labels_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    pivots: tp.Array1d,
    up_th: tp.FlexArray1dLike,
    down_th: tp.FlexArray1dLike,
) -> tp.Array1d:
    """Median values normalized between 0 (downtrend) and 1 (uptrend) but capped once
    the threshold defined at the beginning of the trend is exceeded.

    Thresholds are selected at the position of the pivot that starts the segment.
    Rows where `high` or `low` is NaN are skipped and stay NaN; segment extremes
    ignore NaN as well."""
    up_th_ = to_1d_array_nb(np.asarray(up_th))
    down_th_ = to_1d_array_nb(np.asarray(down_th))
    bincs_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
    idxs = np.flatnonzero(pivots)
    if idxs.shape[0] == 0:
        return bincs_trend_labels

    for k in range(1, idxs.shape[0]):
        prev_i = idxs[k - 1]
        next_i = idxs[k]
        _up_th = 1 + abs(flex_select_1d_nb(up_th_, prev_i))
        _down_th = 1 - abs(flex_select_1d_nb(down_th_, prev_i))
        # NaN-aware extremes (same as in `binc_trend_labels_1d_nb`): a single missing bar
        # inside the segment must not turn every label of the segment into NaN
        _min = np.nanmin(low[prev_i : next_i + 1])
        _max = np.nanmax(high[prev_i : next_i + 1])

        for i in range(prev_i, next_i):
            if not np.isnan(high[i]) and not np.isnan(low[i]):
                _med = (high[i] + low[i]) / 2
                if pivots[next_i] == Pivot.Peak:
                    if not np.isnan(_up_th):
                        _start = _max / _up_th
                        _end = _min * _up_th
                        if _max >= _end and _med <= _start:
                            bincs_trend_labels[i] = 1
                        else:
                            bincs_trend_labels[i] = 1 - (_med - _start) / (_max - _start)
                else:
                    if not np.isnan(_down_th):
                        _start = _min / _down_th
                        _end = _max * _down_th
                        if _min <= _end and _med >= _start:
                            bincs_trend_labels[i] = 0
                        else:
                            bincs_trend_labels[i] = 1 - (_med - _min) / (_start - _min)

    return bincs_trend_labels


@register_chunkable(
    size=ch.ArraySizer(arg_query="pivots", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        pivots=ch.ArraySlicer(axis=1),
        up_th=base_ch.FlexArraySlicer(axis=1),
        down_th=base_ch.FlexArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bincs_trend_labels_nb(
    high: tp.Array2d,
    low: tp.Array2d,
    pivots: tp.Array2d,
    up_th: tp.FlexArray2dLike,
    down_th: tp.FlexArray2dLike,
) -> tp.Array2d:
    """2-dim version of `bincs_trend_labels_1d_nb`.

    `up_th` and `down_th` are selected flexibly per column."""
    up_th_ = to_2d_array_nb(np.asarray(up_th))
    down_th_ = to_2d_array_nb(np.asarray(down_th))

    bincs_trend_labels = np.empty(pivots.shape, dtype=float_)
    for col in prange(pivots.shape[1]):
        bincs_trend_labels[:, col] = bincs_trend_labels_1d_nb(
            high[:, col],
            low[:, col],
            pivots[:, col],
            flex_select_col_nb(up_th_, col),
            flex_select_col_nb(down_th_, col),
        )
    return bincs_trend_labels


@register_jitted(cache=True)
def pct_trend_labels_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    pivots: tp.Array1d,
    normalize: bool = False,
) -> tp.Array1d:
    """Percentage change of median values relative to the next pivot.

    The reference value is `high` at the next pivot if it is a peak and `low` otherwise.
    With `normalize`, the difference is divided by the larger of the two values involved
    (the peak value in an uptrend, the median in a downtrend); without it, by the smaller one."""
    pct_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
    idxs = np.flatnonzero(pivots)
    if idxs.shape[0] == 0:
        return pct_trend_labels

    for k in range(1, idxs.shape[0]):
        prev_i = idxs[k - 1]
        next_i = idxs[k]

        for i in range(prev_i, next_i):
            _med = (high[i] + low[i]) / 2
            if pivots[next_i] == Pivot.Peak:
                if normalize:
                    pct_trend_labels[i] = (high[next_i] - _med) / high[next_i]
                else:
                    pct_trend_labels[i] = (high[next_i] - _med) / _med
            else:
                if normalize:
                    pct_trend_labels[i] = (low[next_i] - _med) / _med
                else:
                    pct_trend_labels[i] = (low[next_i] - _med) / low[next_i]

    return pct_trend_labels


@register_chunkable(
    size=ch.ArraySizer(arg_query="pivots", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        pivots=ch.ArraySlicer(axis=1),
        normalize=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pct_trend_labels_nb(
    high: tp.Array2d,
    low: tp.Array2d,
    pivots: tp.Array2d,
    normalize: bool = False,
) -> tp.Array2d:
    """2-dim version of `pct_trend_labels_1d_nb`."""
    pct_trend_labels = np.empty(pivots.shape, dtype=float_)
    for col in prange(pivots.shape[1]):
        pct_trend_labels[:, col] = pct_trend_labels_1d_nb(
            high[:, col],
            low[:, col],
            pivots[:, col],
            normalize=normalize,
        )
    return pct_trend_labels


@register_jitted(cache=True)
def trend_labels_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    up_th: tp.FlexArray1dLike,
    down_th: tp.FlexArray1dLike,
    mode: int = TrendLabelMode.Binary,
) -> tp.Array1d:
    """Trend labels based on `vectorbtpro.labels.enums.TrendLabelMode`.

    Pivots are computed with `pivots_1d_nb` and then passed to the labeling
    function that corresponds to `mode`.

    Raises:
        ValueError: If `mode` is not a field of `TrendLabelMode`."""
    pivots = pivots_1d_nb(high, low, up_th, down_th)
    if mode == TrendLabelMode.Binary:
        return bin_trend_labels_1d_nb(pivots)
    if mode == TrendLabelMode.BinaryCont:
        return binc_trend_labels_1d_nb(high, low, pivots)
    if mode == TrendLabelMode.BinaryContSat:
        return bincs_trend_labels_1d_nb(high, low, pivots, up_th, down_th)
    if mode == TrendLabelMode.PctChange:
        return pct_trend_labels_1d_nb(high, low, pivots, normalize=False)
    if mode == TrendLabelMode.PctChangeNorm:
        return pct_trend_labels_1d_nb(high, low, pivots, normalize=True)
    raise ValueError("Invalid trend mode")


@register_chunkable(
    size=ch.ArraySizer(arg_query="high", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        up_th=base_ch.FlexArraySlicer(axis=1),
        down_th=base_ch.FlexArraySlicer(axis=1),
        mode=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def trend_labels_nb(
    high: tp.Array2d,
    low: tp.Array2d,
    up_th: tp.FlexArray2dLike,
    down_th: tp.FlexArray2dLike,
    mode: tp.FlexArray1dLike = TrendLabelMode.Binary,
) -> tp.Array2d:
    """2-dim version of `trend_labels_1d_nb`.

    `up_th`, `down_th`, and `mode` are selected flexibly per column."""
    up_th_ = to_2d_array_nb(np.asarray(up_th))
    down_th_ = to_2d_array_nb(np.asarray(down_th))
    mode_ = to_1d_array_nb(np.asarray(mode))

    trend_labels = np.empty(high.shape, dtype=float_)
    for col in prange(high.shape[1]):
        trend_labels[:, col] = trend_labels_1d_nb(
            high[:, col],
            low[:, col],
            flex_select_col_nb(up_th_, col),
            flex_select_col_nb(down_th_, col),
            mode=flex_select_1d_nb(mode_, col),
        )
    return trend_labels


# ############# BOLB ############# #


@register_jitted(cache=True)
def breakout_labels_1d_nb(
    high: tp.Array1d,
    low: tp.Array1d,
    window: int = 14,
    up_th: tp.FlexArray1dLike = np.inf,
    down_th: tp.FlexArray1dLike = np.inf,
    wait: int = 1,
) -> tp.Array1d:
    """For each value, return 1 if any value in the next period is greater than the
    positive threshold (in %), -1 if less than the negative threshold, and 0 otherwise.

    First hit wins. Continue search if both thresholds were hit at the same time.

    The upper threshold is applied to `low` and the lower threshold to `high` of the
    current row. Rows where `high` or `low` is NaN are skipped, both as the current
    row and as a future row."""
    up_th_ = to_1d_array_nb(np.asarray(up_th))
    down_th_ = to_1d_array_nb(np.asarray(down_th))
    breakout_labels = np.full(high.shape, 0, dtype=float_)

    for i in range(high.shape[0]):
        if not np.isnan(high[i]) and not np.isnan(low[i]):
            _up_th = 1 + abs(flex_select_1d_nb(up_th_, i))
            _down_th = 1 - abs(flex_select_1d_nb(down_th_, i))

            for j in range(i + wait, min(i + window + wait, high.shape[0])):
                if not np.isnan(high[j]) and not np.isnan(low[j]):
                    # Evaluate both conditions first: breaking right after the upper hit
                    # would make the "both hit" case undetectable
                    up_hit = not np.isnan(_up_th) and high[j] >= low[i] * _up_th
                    down_hit = not np.isnan(_down_th) and low[j] <= high[i] * _down_th
                    if up_hit and down_hit:
                        # Ambiguous bar: order of hits is unknown, keep searching
                        continue
                    if up_hit:
                        breakout_labels[i] = 1
                        break
                    if down_hit:
                        breakout_labels[i] = -1
                        break

    return breakout_labels


@register_chunkable(
    size=ch.ArraySizer(arg_query="high", axis=1),
    arg_take_spec=dict(
        high=ch.ArraySlicer(axis=1),
        low=ch.ArraySlicer(axis=1),
        window=base_ch.FlexArraySlicer(),
        up_th=base_ch.FlexArraySlicer(axis=1),
        down_th=base_ch.FlexArraySlicer(axis=1),
        wait=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def breakout_labels_nb(
    high: tp.Array2d,
    low: tp.Array2d,
    window: tp.FlexArray1dLike = 14,
    up_th: tp.FlexArray2dLike = np.inf,
    down_th: tp.FlexArray2dLike = np.inf,
    wait: tp.FlexArray1dLike = 1,
) -> tp.Array2d:
    """2-dim version of `breakout_labels_1d_nb`.

    `window`, `up_th`, `down_th`, and `wait` are selected flexibly per column."""
    window_ = to_1d_array_nb(np.asarray(window))
    up_th_ = to_2d_array_nb(np.asarray(up_th))
    down_th_ = to_2d_array_nb(np.asarray(down_th))
    wait_ = to_1d_array_nb(np.asarray(wait))

    breakout_labels = np.empty(high.shape, dtype=float_)
    for col in prange(high.shape[1]):
        breakout_labels[:, col] = breakout_labels_1d_nb(
            high[:, col],
            low[:, col],
            window=flex_select_1d_nb(window_, col),
            up_th=flex_select_col_nb(up_th_, col),
            down_th=flex_select_col_nb(down_th_, col),
            wait=flex_select_1d_nb(wait_, col),
        )
    return breakout_labels
# ==================================== VBTPROXYZ
==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Modules for working with OHLC(V) data.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.ohlcv.accessors import * from vectorbtpro.ohlcv.nb import * __exclude_from__all__ = [ "enums", ] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Custom Pandas accessors for OHLC(V) data. Methods can be accessed as follows: * `OHLCVDFAccessor` -> `pd.DataFrame.vbt.ohlcv.*` The accessors inherit `vectorbtpro.generic.accessors`. !!! note Accessors do not utilize caching. ## Column names By default, vectorbt searches for columns with names 'open', 'high', 'low', 'close', and 'volume' (case doesn't matter). You can change the naming either using `feature_map` in `vectorbtpro._settings.ohlcv`, or by providing `feature_map` directly to the accessor. ```pycon >>> from vectorbtpro import * >>> df = pd.DataFrame({ ... 'my_open1': [2, 3, 4, 3.5, 2.5], ... 'my_high2': [3, 4, 4.5, 4, 3], ... 'my_low3': [1.5, 2.5, 3.5, 2.5, 1.5], ... 'my_close4': [2.5, 3.5, 4, 3, 2], ... 'my_volume5': [10, 11, 10, 9, 10] ... 
if tp.TYPE_CHECKING:
    from vectorbtpro.data.base import Data as DataT
else:
    # String placeholder used at runtime so the heavy import is only done for type checkers
    DataT = "Data"

__all__ = [
    "OHLCVDFAccessor",
]

__pdoc__ = {}

OHLCVDFAccessorT = tp.TypeVar("OHLCVDFAccessorT", bound="OHLCVDFAccessor")


@register_df_vbt_accessor("ohlcv")
class OHLCVDFAccessor(OHLCDataMixin, GenericDFAccessor):
    """Accessor on top of OHLCV data. For DataFrames only.

    Accessible via `pd.DataFrame.vbt.ohlcv`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        feature_map: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        # `feature_map` is forwarded to the parent initializer as well as stored locally
        GenericDFAccessor.__init__(self, wrapper, obj=obj, feature_map=feature_map, **kwargs)

        self._feature_map = feature_map

    @hybrid_property
    def df_accessor_cls(cls_or_self) -> tp.Type["OHLCVDFAccessor"]:
        """Accessor class for `pd.DataFrame`."""
        return OHLCVDFAccessor

    @property
    def feature_map(self) -> tp.Kwargs:
        """Mapping from column labels to feature names.

        Merges `feature_map` from `vectorbtpro._settings.ohlcv` with the mapping
        passed to the constructor (the latter is merged on top)."""
        from vectorbtpro._settings import settings

        ohlcv_cfg = settings["ohlcv"]

        return merge_dicts(ohlcv_cfg["feature_map"], self._feature_map)

    @property
    def feature_wrapper(self) -> ArrayWrapper:
        """Wrapper whose columns are translated through `OHLCVDFAccessor.feature_map`.

        Columns without an entry in the map are kept as they are."""
        # Resolve the property once: each access re-merges the settings, and the
        # original lambda accessed it up to twice for every single column.
        feature_map = self.feature_map
        new_columns = self.wrapper.columns.map(lambda x: feature_map[x] if x in feature_map else x)
        return self.wrapper.replace(columns=new_columns)

    @property
    def symbol_wrapper(self) -> ArrayWrapper:
        """Single-column wrapper (column label None) sharing this accessor's index."""
        return ArrayWrapper(self.wrapper.index, [None], 1)

    def select_feature_idxs(self: OHLCVDFAccessorT, idxs: tp.MaybeSequence[int], **kwargs) -> OHLCVDFAccessorT:
        """Select one or more features (columns) by integer position.

        Additional keyword arguments are accepted but not used."""
        return self.iloc[:, idxs]

    def select_symbol_idxs(self: OHLCVDFAccessorT, idxs: tp.MaybeSequence[int], **kwargs) -> OHLCVDFAccessorT:
        """Not supported by this accessor; always raises `NotImplementedError`."""
        raise NotImplementedError

    def get(
        self,
        features: tp.Union[None, tp.MaybeFeatures] = None,
        symbols: tp.Union[None, tp.MaybeSymbols] = None,
        feature: tp.Optional[tp.Feature] = None,
        symbol: tp.Optional[tp.Symbol] = None,
        **kwargs,
    ) -> tp.SeriesFrame:
        """Get one or more features.

        Args:
            features: One or multiple features. Mutually exclusive with `feature`.
            symbols: Not supported; must be None.
            feature: Single feature. Mutually exclusive with `features`.
            symbol: Not supported; must be None.
            **kwargs: Accepted but not used.

        Returns:
            The whole object if neither `features` nor `feature` is provided,
            otherwise the positional selection of the resolved feature index(es).

        Raises:
            ValueError: If both `features` and `feature` are provided,
                or if `symbols` or `symbol` is provided.
        """
        if features is not None and feature is not None:
            raise ValueError("Must provide either features or feature, not both")
        if symbols is not None or symbol is not None:
            raise ValueError("Cannot provide symbols")

        if feature is not None:
            features = feature
            single_feature = True
        else:
            if features is None:
                return self.obj
            single_feature = not self.has_multiple_keys(features)
        if not single_feature:
            feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
        else:
            feature_idxs = self.get_feature_idx(features, raise_error=True)
        return self.obj.iloc[:, feature_idxs]
# ############# Conversion ############# #

def to_data(self, data_cls: tp.Optional[tp.Type[DataT]] = None, **kwargs) -> DataT:
    """Convert to a `vectorbtpro.data.base.Data` instance.

    Uses `data_cls` if provided, otherwise `vectorbtpro.data.base.Data`.
    Columns are passed as features (`columns_are_symbols=False`)."""
    if data_cls is not None:
        return data_cls.from_data(self.obj, columns_are_symbols=False, **kwargs)

    from vectorbtpro.data.base import Data

    return Data.from_data(self.obj, columns_are_symbols=False, **kwargs)

# ############# Transforming ############# #

def mirror_ohlc(
    self: OHLCVDFAccessorT,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    start_value: tp.ArrayLike = np.nan,
    ref_feature: tp.ArrayLike = -1,
) -> tp.Frame:
    """Mirror OHLC features.

    Resolves `vectorbtpro.ohlcv.nb.mirror_ohlc_nb` with the `jitted` and `chunked` options,
    runs it on the available price features, and returns a copy of the DataFrame
    where those features are replaced by the mirrored values.

    A string `ref_feature` is mapped using `vectorbtpro.ohlcv.enums.PriceFeature`."""
    if isinstance(ref_feature, str):
        ref_feature = map_enum_fields(ref_feature, enums.PriceFeature)

    # Column position of each price feature (None if the feature is missing)
    feature_idxs = {name: self.get_feature_idx(name) for name in ("Open", "High", "Low", "Close")}

    func = jit_reg.resolve_option(nb.mirror_ohlc_nb, jitted)
    func = ch_reg.resolve_option(func, chunked)

    # Keyword names expected by the Numba function are the lower-cased feature names
    feature_arrs = {
        name.lower(): None if idx is None else to_2d_array(self.obj.iloc[:, idx])
        for name, idx in feature_idxs.items()
    }
    mirrored = func(
        self.symbol_wrapper.shape_2d,
        start_value=to_1d_array(start_value),
        ref_feature=to_1d_array(ref_feature),
        **feature_arrs,
    )

    # The function returns (open, high, low, close) in the same order as `feature_idxs`
    df = self.obj.copy()
    for idx, new_values in zip(feature_idxs.values(), mirrored):
        if idx is not None:
            df.iloc[:, idx] = new_values[:, 0]
    return df
def resample(self: OHLCVDFAccessorT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> OHLCVDFAccessorT:
    """Perform resampling on `OHLCVDFAccessor`.

    Each column is reduced according to the feature it represents (case-insensitive):
    first value for open, maximum for high, minimum for low, last value for close,
    and sum for volume.

    Args:
        *args: Passed to `self.wrapper.resample_meta` if `wrapper_meta` is None.
        wrapper_meta: Precomputed resampling metadata; its "resampler" and
            "new_wrapper" entries are used.
        **kwargs: Passed to `self.wrapper.resample_meta` if `wrapper_meta` is None.

    Returns:
        New accessor instance built from the resampled columns.

    Raises:
        ValueError: If a column cannot be matched to any of the five features.
    """
    if wrapper_meta is None:
        wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
    resampler = wrapper_meta["resampler"]

    reduce_funcs = {
        "open": generic_nb.first_reduce_nb,
        "high": generic_nb.max_reduce_nb,
        "low": generic_nb.min_reduce_nb,
        "close": generic_nb.last_reduce_nb,
        "volume": generic_nb.sum_reduce_nb,
    }

    sr_dct = {}
    # The feature name (after applying the feature map) decides the reducer, but the data
    # must be taken from the original column: indexing `self.obj` by the mapped name fails
    # with KeyError as soon as the feature map renames a column. Selecting by position and
    # keying by the original label also keeps the output columns aligned with the labels
    # the object actually carries.
    for i, (column, feature) in enumerate(zip(self.wrapper.columns, self.feature_wrapper.columns)):
        reduce_func = reduce_funcs.get(feature.lower()) if isinstance(feature, str) else None
        if reduce_func is None:
            raise ValueError(f"Cannot match feature '{feature}'")
        sr_dct[column] = self.obj.iloc[:, i].vbt.resample_apply(resampler, reduce_func)

    new_obj = pd.DataFrame(sr_dct)
    return self.replace(
        wrapper=wrapper_meta["new_wrapper"],
        obj=new_obj,
    )
# ############# Stats ############# #

@property
def stats_defaults(self) -> tp.Kwargs:
    """Defaults for `OHLCVDFAccessor.stats`.

    Merges `vectorbtpro.generic.accessors.GenericAccessor.stats_defaults` and
    `stats` from `vectorbtpro._settings.ohlcv`."""
    from vectorbtpro._settings import settings

    generic_defaults = GenericAccessor.stats_defaults.__get__(self)
    return merge_dicts(generic_defaults, settings["ohlcv"]["stats"])

# NOTE: the parameter names of each `calc_func` (`self`, `ohlc`, `volume`) are kept as-is
_metrics: tp.ClassVar[Config] = HybridConfig(
    {
        "start_index": {
            "title": "Start Index",
            "calc_func": lambda self: self.wrapper.index[0],
            "agg_func": None,
            "tags": "wrapper",
        },
        "end_index": {
            "title": "End Index",
            "calc_func": lambda self: self.wrapper.index[-1],
            "agg_func": None,
            "tags": "wrapper",
        },
        "total_duration": {
            "title": "Total Duration",
            "calc_func": lambda self: len(self.wrapper.index),
            "apply_to_timedelta": True,
            "agg_func": None,
            "tags": "wrapper",
        },
        "first_price": {
            "title": "First Price",
            "calc_func": lambda ohlc: generic_nb.bfill_1d_nb(ohlc.values.flatten())[0],
            "resolve_ohlc": True,
            "tags": ["ohlcv", "ohlc"],
        },
        "lowest_price": {
            "title": "Lowest Price",
            "calc_func": lambda ohlc: ohlc.values.min(),
            "resolve_ohlc": True,
            "tags": ["ohlcv", "ohlc"],
        },
        "highest_price": {
            "title": "Highest Price",
            "calc_func": lambda ohlc: ohlc.values.max(),
            "resolve_ohlc": True,
            "tags": ["ohlcv", "ohlc"],
        },
        "last_price": {
            "title": "Last Price",
            "calc_func": lambda ohlc: generic_nb.ffill_1d_nb(ohlc.values.flatten())[-1],
            "resolve_ohlc": True,
            "tags": ["ohlcv", "ohlc"],
        },
        "first_volume": {
            "title": "First Volume",
            "calc_func": lambda volume: generic_nb.bfill_1d_nb(volume.values)[0],
            "resolve_volume": True,
            "tags": ["ohlcv", "volume"],
        },
        "lowest_volume": {
            "title": "Lowest Volume",
            "calc_func": lambda volume: volume.values.min(),
            "resolve_volume": True,
            "tags": ["ohlcv", "volume"],
        },
        "highest_volume": {
            "title": "Highest Volume",
            "calc_func": lambda volume: volume.values.max(),
            "resolve_volume": True,
            "tags": ["ohlcv", "volume"],
        },
        "last_volume": {
            "title": "Last Volume",
            "calc_func": lambda volume: generic_nb.ffill_1d_nb(volume.values)[-1],
            "resolve_volume": True,
            "tags": ["ohlcv", "volume"],
        },
    }
)

@property
def metrics(self) -> Config:
    """Metrics config of this accessor (`_metrics`)."""
    return self._metrics
# ############# Plotting ############# #

def plot_ohlc(
    self,
    ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
    trace_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot OHLC data.

    Args:
        ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

            Pass None to use the default.
        trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    import plotly.graph_objects as go
    from vectorbtpro.utils.figure import make_figure
    from vectorbtpro._settings import settings

    color_schema = settings["plotting"]["color_schema"]

    trace_kwargs = {} if trace_kwargs is None else trace_kwargs
    add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs

    # Set up figure
    if fig is None:
        fig = make_figure()
    fig.update_layout(**layout_kwargs)

    # Resolve the trace class
    if ohlc_type is None:
        ohlc_type = settings["ohlcv"]["ohlc_type"]
    if not isinstance(ohlc_type, str):
        plot_obj = ohlc_type
    else:
        type_name = ohlc_type.lower()
        if type_name == "ohlc":
            plot_obj = go.Ohlc
        elif type_name == "candlestick":
            plot_obj = go.Candlestick
        else:
            raise ValueError("Plot type can be either 'OHLC' or 'Candlestick'")

    # Fill color is set for every trace type except `go.Ohlc`
    with_fillcolor = plot_obj is not go.Ohlc
    direction_styles = {}
    for direction in ("increasing", "decreasing"):
        color = color_schema[direction]
        style = dict(line=dict(color=color))
        if with_fillcolor:
            style["fillcolor"] = color
        direction_styles[direction] = style

    def_trace_kwargs = dict(
        x=self.wrapper.index,
        open=self.open,
        high=self.high,
        low=self.low,
        close=self.close,
        name="OHLC",
        opacity=0.75,
        **direction_styles,
    )
    _trace_kwargs = merge_dicts(def_trace_kwargs, trace_kwargs)
    fig.add_trace(plot_obj(**_trace_kwargs), **add_trace_kwargs)

    # Translate the axis references of the added trace ("x2") into layout keys ("xaxis2")
    added_trace = fig.data[-1]
    xref = added_trace["xaxis"] if added_trace["xaxis"] is not None else "x"
    yref = added_trace["yaxis"] if added_trace["yaxis"] is not None else "y"
    xaxis = "xaxis" + xref[1:]
    yaxis = "yaxis" + yref[1:]  # computed like in the original; not used below

    # Hide the range slider unless the caller configured it explicitly
    if "rangeslider_visible" not in layout_kwargs.get(xaxis, {}):
        fig.update_layout({xaxis: dict(rangeslider_visible=False)})
    return fig
def plot_volume(
    self,
    trace_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot volume data.

    Bars are colored by the sign of `close - open`: "increasing" if positive,
    "decreasing" if negative, and "gray" otherwise (colors are taken from
    `color_schema` of the plotting settings).

    Args:
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    import plotly.graph_objects as go
    from vectorbtpro.utils.figure import make_figure
    from vectorbtpro._settings import settings

    color_schema = settings["plotting"]["color_schema"]

    trace_kwargs = {} if trace_kwargs is None else trace_kwargs
    add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs

    # Set up figure
    if fig is None:
        fig = make_figure()
    fig.update_layout(**layout_kwargs)

    # Start from gray everywhere, then overwrite rising and falling bars
    price_change = self.close.values - self.open.values
    marker_colors = np.full(self.volume.shape, color_schema["gray"], dtype=object)
    marker_colors[price_change > 0] = color_schema["increasing"]
    marker_colors[price_change < 0] = color_schema["decreasing"]

    def_trace_kwargs = dict(
        x=self.wrapper.index,
        y=self.volume,
        marker=dict(color=marker_colors, line_width=0),
        opacity=0.5,
        name="Volume",
    )
    _trace_kwargs = merge_dicts(def_trace_kwargs, trace_kwargs)
    fig.add_trace(go.Bar(**_trace_kwargs), **add_trace_kwargs)
    return fig
def plot(
    self,
    ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
    plot_volume: tp.Optional[bool] = None,
    ohlc_trace_kwargs: tp.KwargsLike = None,
    volume_trace_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    volume_add_trace_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot OHLC(V) data.

    Args:
        ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

            Pass None to use the default.
        plot_volume (bool): Whether to plot volume beneath.

            Defaults to True if volume is available.
        ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
        volume_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for OHLC.
        volume_add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for volume.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Usage:
        ```pycon
        >>> vbt.YFData.pull("BTC-USD").get().vbt.ohlcv.plot().show()
        ```

        [=100% "100%"]{: .candystripe .candystripe-animate }

        ![](/assets/images/api/ohlcv_plot.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/ohlcv_plot.dark.svg#only-dark){: .iimg loading=lazy }
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    from vectorbtpro.utils.figure import make_figure, make_subplots

    with_volume = plot_volume
    if with_volume is None:
        with_volume = self.volume is not None
    if with_volume:
        # OHLC goes into the first row, volume into the second one (unless overridden)
        add_trace_kwargs = merge_dicts(dict(row=1, col=1), add_trace_kwargs)
        volume_add_trace_kwargs = merge_dicts(dict(row=2, col=1), volume_add_trace_kwargs)

    # Set up figure
    if fig is None:
        if not with_volume:
            fig = make_figure()
        else:
            fig = make_subplots(
                rows=2,
                cols=1,
                shared_xaxes=True,
                vertical_spacing=0,
                row_heights=[0.7, 0.3],
            )
        fig.update_layout(
            showlegend=True,
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
        )
        if with_volume:
            fig.update_layout(
                xaxis2=dict(showgrid=True),
                yaxis2=dict(showgrid=True),
            )
    fig.update_layout(**layout_kwargs)

    fig = self.plot_ohlc(
        ohlc_type=ohlc_type,
        trace_kwargs=ohlc_trace_kwargs,
        add_trace_kwargs=add_trace_kwargs,
        fig=fig,
    )
    if not with_volume:
        return fig
    return self.plot_volume(
        trace_kwargs=volume_trace_kwargs,
        add_trace_kwargs=volume_add_trace_kwargs,
        fig=fig,
    )
@property
def plots_defaults(self) -> tp.Kwargs:
    """Defaults for `OHLCVDFAccessor.plots`.

    Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults` and
    `plots` from `vectorbtpro._settings.ohlcv`."""
    from vectorbtpro._settings import settings

    generic_defaults = GenericAccessor.plots_defaults.__get__(self)
    return merge_dicts(generic_defaults, settings["ohlcv"]["plots"])

_subplots: tp.ClassVar[Config] = HybridConfig(
    {
        "plot": {
            "title": "OHLC",
            "xaxis_kwargs": {"showgrid": True, "rangeslider_visible": False},
            "yaxis_kwargs": {"showgrid": True},
            "check_is_not_grouped": True,
            "plot_func": "plot",
            "plot_volume": False,
            "tags": "ohlcv",
        },
    }
)

@property
def subplots(self) -> Config:
    """Subplots config of this accessor (`_subplots`)."""
    return self._subplots


# Module-level statements that follow the class definition
OHLCVDFAccessor.override_metrics_doc(__pdoc__)
OHLCVDFAccessor.override_subplots_doc(__pdoc__)


# The definitions below follow a separate license header and module docstring
# ("Named tuples and enumerated types for OHLC(V) data") in the packed dump.

__pdoc__all__ = __all__ = [
    "PriceFeature",
]

__pdoc__ = {}


# ############# Enums ############# #


class PriceFeatureT(tp.NamedTuple):
    # Integer code of each price feature
    Open: int = 0
    High: int = 1
    Low: int = 2
    Close: int = 3


PriceFeature = PriceFeatureT()
"""_"""

__pdoc__[
    "PriceFeature"
] = f"""Price feature.

```python
{prettify(PriceFeature)}
```
"""
@register_jitted(cache=True)
def ohlc_every_1d_nb(price: tp.Array1d, n: tp.FlexArray1dLike) -> tp.Array2d:
    """Aggregate every `n` price points into an OHLC point.

    Args:
        price: 1-dim array of prices.
        n: Number of price points per group; selected flexibly by the index
            of the output row, so it may differ from group to group.

    Returns:
        2-dim array with one row per group and four columns: first price,
        highest price, lowest price, and last price of the group. A group
        with a non-positive `n` produces a row of NaN.
    """
    n_ = to_1d_array_nb(np.asarray(n))
    # Allocate for the worst case of one group per price point; trimmed on return
    out = np.empty((price.shape[0], 4), dtype=float_)
    vmin = np.inf
    vmax = -np.inf
    k = 0  # index of the current output row
    start_i = 0  # index of the first price point of the current group
    for i in range(price.shape[0]):
        _n = flex_select_1d_pr_nb(n_, k)
        if _n <= 0:
            # NOTE(review): this branch moves on to the next price point without advancing
            # `start_i`, so the following group appears to never satisfy `i == start_i`
            # and its open (column 0) would stay uninitialized - confirm intended behavior
            out[k, 0] = np.nan
            out[k, 1] = np.nan
            out[k, 2] = np.nan
            out[k, 3] = np.nan
            vmin = np.inf
            vmax = -np.inf
            if i < price.shape[0] - 1:
                k = k + 1
            continue
        # Track the running low and high of the current group
        if price[i] < vmin:
            vmin = price[i]
        if price[i] > vmax:
            vmax = price[i]
        if i == start_i:
            out[k, 0] = price[i]
        # Close the group once it is full or the prices are exhausted
        if i == start_i + _n - 1 or i == price.shape[0] - 1:
            out[k, 1] = vmax
            out[k, 2] = vmin
            out[k, 3] = price[i]
            vmin = np.inf
            vmax = -np.inf
            # Do not advance past the last row so that `k + 1` equals the number of rows
            if i < price.shape[0] - 1:
                k = k + 1
            start_i = start_i + _n
    return out[: k + 1]
@register_jitted(cache=True)
def mirror_ohlc_1d_nb(
    n_rows: int,
    open: tp.Optional[tp.Array1d] = None,
    high: tp.Optional[tp.Array1d] = None,
    low: tp.Optional[tp.Array1d] = None,
    close: tp.Optional[tp.Array1d] = None,
    start_value: float = np.nan,
    ref_feature: int = -1,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
    """Mirror OHLC.

    On the first row that has a reference feature, all features are rescaled by a common
    factor such that the reference feature becomes `start_value` (no rescaling if
    `start_value` is NaN). On each following row, the reference feature is rebuilt from
    the negated cumulative log returns of the original reference feature, and the other
    features are derived from it using the inverted ratio between the original reference
    and the original feature.

    Args:
        n_rows: Number of rows to process.
        open: Open price. None is treated as all-NaN.
        high: High price. None is treated as all-NaN.
        low: Low price. None is treated as all-NaN.
        close: Close price. None is treated as all-NaN.
        start_value: Value of the reference feature on the first row. NaN keeps the original value.
        ref_feature: Reference feature, see `vectorbtpro.ohlcv.enums.PriceFeature`.

            With -1, the reference is chosen per row as the first non-NaN feature
            in the order open, close, high, low.

    Returns:
        Tuple of new open, high, low, and close arrays. Rows where no reference
        feature can be determined are filled with NaN.
    """
    new_open = np.empty(n_rows, dtype=float_)
    new_high = np.empty(n_rows, dtype=float_)
    new_low = np.empty(n_rows, dtype=float_)
    new_close = np.empty(n_rows, dtype=float_)

    cumsum = 0.0  # running sum of negated log returns of the reference feature
    first_idx = -1  # first row with a reference feature
    factor = 1.0  # rescaling factor determined on the first row

    for i in range(n_rows):
        _open = open[i] if open is not None else np.nan
        _high = high[i] if high is not None else np.nan
        _low = low[i] if low is not None else np.nan
        _close = close[i] if close is not None else np.nan

        if ref_feature == PriceFeature.Open or (ref_feature == -1 and not np.isnan(_open)):
            if first_idx == -1:
                first_idx = i
                if not np.isnan(start_value):
                    new_open[i] = start_value
                else:
                    new_open[i] = _open
                factor = new_open[i] / _open
                new_high[i] = _high * factor if not np.isnan(_high) else np.nan
                new_low[i] = _low * factor if not np.isnan(_low) else np.nan
                new_close[i] = _close * factor if not np.isnan(_close) else np.nan
            else:
                prev_open = open[i - 1] if open is not None else np.nan
                cumsum += -np.log(_open / prev_open)
                new_open[i] = open[first_idx] * np.exp(cumsum) * factor
                # High and low swap roles: the new high is derived from the original low
                new_high[i] = (_open / _low) * new_open[i] if not np.isnan(_low) else np.nan
                new_low[i] = (_open / _high) * new_open[i] if not np.isnan(_high) else np.nan
                new_close[i] = (_open / _close) * new_open[i] if not np.isnan(_close) else np.nan

        elif ref_feature == PriceFeature.Close or (ref_feature == -1 and not np.isnan(_close)):
            if first_idx == -1:
                first_idx = i
                if not np.isnan(start_value):
                    new_close[i] = start_value
                else:
                    new_close[i] = _close
                factor = new_close[i] / _close
                new_open[i] = _open * factor if not np.isnan(_open) else np.nan
                new_high[i] = _high * factor if not np.isnan(_high) else np.nan
                new_low[i] = _low * factor if not np.isnan(_low) else np.nan
            else:
                prev_close = close[i - 1] if close is not None else np.nan
                cumsum += -np.log(_close / prev_close)
                new_close[i] = close[first_idx] * np.exp(cumsum) * factor
                new_open[i] = (_close / _open) * new_close[i] if not np.isnan(_open) else np.nan
                new_high[i] = (_close / _low) * new_close[i] if not np.isnan(_low) else np.nan
                new_low[i] = (_close / _high) * new_close[i] if not np.isnan(_high) else np.nan

        elif ref_feature == PriceFeature.High or (ref_feature == -1 and not np.isnan(_high)):
            if first_idx == -1:
                first_idx = i
                if not np.isnan(start_value):
                    new_high[i] = start_value
                else:
                    new_high[i] = _high
                factor = new_high[i] / _high
                new_open[i] = _open * factor if not np.isnan(_open) else np.nan
                # Fixed: low and close were additionally multiplied by `new_high[i]`,
                # unlike the first-row rescaling in every other branch
                new_low[i] = _low * factor if not np.isnan(_low) else np.nan
                new_close[i] = _close * factor if not np.isnan(_close) else np.nan
            else:
                prev_high = high[i - 1] if high is not None else np.nan
                cumsum += -np.log(_high / prev_high)
                new_high[i] = high[first_idx] * np.exp(cumsum) * factor
                new_open[i] = (_high / _open) * new_high[i] if not np.isnan(_open) else np.nan
                # Fixed: this value was written back into `new_high[i]`, which overwrote
                # the reference and left `new_low[i]` uninitialized. It now mirrors the
                # low-reference branch, which derives `new_high` from `new_low`.
                new_low[i] = (_high / _low) * new_high[i] if not np.isnan(_low) else np.nan
                new_close[i] = (_high / _close) * new_high[i] if not np.isnan(_close) else np.nan

        elif ref_feature == PriceFeature.Low or (ref_feature == -1 and not np.isnan(_low)):
            if first_idx == -1:
                first_idx = i
                if not np.isnan(start_value):
                    new_low[i] = start_value
                else:
                    new_low[i] = _low
                factor = new_low[i] / _low
                new_open[i] = _open * factor if not np.isnan(_open) else np.nan
                new_high[i] = _high * factor if not np.isnan(_high) else np.nan
                new_close[i] = _close * factor if not np.isnan(_close) else np.nan
            else:
                prev_low = low[i - 1] if low is not None else np.nan
                cumsum += -np.log(_low / prev_low)
                new_low[i] = low[first_idx] * np.exp(cumsum) * factor
                new_open[i] = (_low / _open) * new_low[i] if not np.isnan(_open) else np.nan
                new_high[i] = (_low / _high) * new_low[i] if not np.isnan(_high) else np.nan
                new_close[i] = (_low / _close) * new_low[i] if not np.isnan(_close) else np.nan

        else:
            # No reference feature available on this row
            new_open[i] = np.nan
            new_high[i] = np.nan
            new_low[i] = np.nan
            new_close[i] = np.nan

    return new_open, new_high, new_low, new_close
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        open=base_ch.ArraySlicer(axis=1),
        high=base_ch.ArraySlicer(axis=1),
        low=base_ch.ArraySlicer(axis=1),
        close=base_ch.ArraySlicer(axis=1),
        start_value=base_ch.FlexArraySlicer(),
        ref_feature=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mirror_ohlc_nb(
    target_shape: tp.Shape,
    open: tp.Optional[tp.Array2d] = None,
    high: tp.Optional[tp.Array2d] = None,
    low: tp.Optional[tp.Array2d] = None,
    close: tp.Optional[tp.Array2d] = None,
    start_value: tp.FlexArray1dLike = np.nan,
    ref_feature: tp.FlexArray1dLike = -1,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
    """2-dim version of `mirror_ohlc_1d_nb`.

    Runs `mirror_ohlc_1d_nb` on each column of `target_shape`, with `start_value`
    and `ref_feature` selected flexibly per column."""
    start_value_ = to_1d_array_nb(np.asarray(start_value))
    ref_feature_ = to_1d_array_nb(np.asarray(ref_feature))
    n_rows = target_shape[0]

    new_open = np.empty(target_shape, dtype=float_)
    new_high = np.empty(target_shape, dtype=float_)
    new_low = np.empty(target_shape, dtype=float_)
    new_close = np.empty(target_shape, dtype=float_)

    for col in prange(target_shape[1]):
        col_open, col_high, col_low, col_close = mirror_ohlc_1d_nb(
            n_rows,
            open[:, col] if open is not None else None,
            high[:, col] if high is not None else None,
            low[:, col] if low is not None else None,
            close[:, col] if close is not None else None,
            start_value=flex_select_1d_nb(start_value_, col),
            ref_feature=flex_select_1d_nb(ref_feature_, col),
        )
        new_open[:, col] = col_open
        new_high[:, col] = col_high
        new_low[:, col] = col_low
        new_close[:, col] = col_close

    return new_open, new_high, new_low, new_close
# ############# Assets ############# #


@register_jitted(cache=True)
def get_long_size_nb(position_before: float, position_now: float) -> float:
    """Get long size.

    Returns the change in the long part of the position when moving
    from `position_before` to `position_now`."""
    if position_before <= 0 and position_now <= 0:
        # Not long before and not long now
        long_size = 0.0
    elif position_before >= 0 and position_now < 0:
        # The entire long position was closed on the way to short
        long_size = -position_before
    elif position_before < 0 and position_now >= 0:
        # Only the part above zero is long
        long_size = position_now
    else:
        long_size = add_nb(position_now, -position_before)
    return long_size
@register_jitted(cache=True)
def get_short_size_nb(position_before: float, position_now: float) -> float:
    """Get short size.

    Returns the change in the short part of the position when moving
    from `position_before` to `position_now`."""
    if position_before >= 0 and position_now >= 0:
        # Not short before and not short now
        short_size = 0.0
    elif position_before >= 0 and position_now < 0:
        # Only the part below zero is short
        short_size = -position_now
    elif position_before < 0 and position_now >= 0:
        # The entire short position was closed (negative change)
        short_size = position_before
    else:
        short_size = add_nb(position_before, -position_now)
    return short_size
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        direction=None,
        init_position=base_ch.FlexArraySlicer(),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_flow_nb(
    target_shape: tp.Shape,
    order_records: tp.RecordArray,
    col_map: tp.GroupMap,
    direction: int = Direction.Both,
    init_position: tp.FlexArray1dLike = 0.0,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get asset flow series per column.

    Returns the total transacted amount of assets at each time step.

    Rows inside the simulation range of a column start at zero and accumulate the
    signed size of each order (negative for sell orders); rows outside remain NaN.
    With a long-only or short-only `direction`, only the respective part of each
    position change is counted.

    Raises:
        ValueError: If order ids within a column are not in ascending order.
    """
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    out = np.full(target_shape, np.nan, dtype=float_)

    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=target_shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )

    for col in prange(col_lens.shape[0]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        out[_sim_start:_sim_end, col] = 0.0
        # Nothing to accumulate for an empty range or a column without orders
        if _sim_start >= _sim_end or col_lens[col] == 0:
            continue

        prev_id = -1
        position_now = flex_select_1d_pc_nb(init_position_, col)
        first_r = col_start_idxs[col]

        for r in range(first_r, first_r + col_lens[col]):
            order_record = order_records[col_idxs[r]]
            i = order_record["idx"]
            # Skip orders outside of the simulation range
            if not (_sim_start <= i < _sim_end):
                continue

            if order_record["id"] < prev_id:
                raise ValueError("Ids must come in ascending order per column")
            prev_id = order_record["id"]

            signed_size = order_record["size"]
            if order_record["side"] == OrderSide.Sell:
                signed_size *= -1
            new_position_now = add_nb(position_now, signed_size)

            if direction == Direction.LongOnly:
                asset_flow = get_long_size_nb(position_now, new_position_now)
            elif direction == Direction.ShortOnly:
                asset_flow = get_short_size_nb(position_now, new_position_now)
            else:
                asset_flow = signed_size

            out[i, col] = add_nb(out[i, col], asset_flow)
            position_now = new_position_now

    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="asset_flow", axis=1),
    arg_take_spec=dict(
        asset_flow=ch.ArraySlicer(axis=1),
        direction=None,
        init_position=base_ch.FlexArraySlicer(),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def assets_nb(
    asset_flow: tp.Array2d,
    direction: int = Direction.Both,
    init_position: tp.FlexArray1dLike = 0.0,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get asset series per column.

    Returns the current position at each time step.

    The position starts at `init_position` and accumulates `asset_flow` within the
    simulation range of each column; rows outside of the range remain NaN. With a
    long-only (short-only) `direction`, only positive (negative) positions are
    reported, the latter as absolute values, and everything else becomes zero."""
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    out = np.full(asset_flow.shape, np.nan, dtype=float_)

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=asset_flow.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )

    for col in prange(asset_flow.shape[1]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        if _sim_start >= _sim_end:
            continue
        position_now = flex_select_1d_pc_nb(init_position_, col)

        for i in range(_sim_start, _sim_end):
            flow_value = asset_flow[i, col]
            position_now = add_nb(position_now, flow_value)
            if direction == Direction.Both:
                out[i, col] = position_now
            elif direction == Direction.LongOnly and position_now > 0:
                out[i, col] = position_now
            elif direction == Direction.ShortOnly and position_now < 0:
                out[i, col] = -position_now
            else:
                out[i, col] = 0.0
    return out


@register_chunkable(
    size=ch.ArraySizer(arg_query="assets", axis=1),
    arg_take_spec=dict(
        assets=ch.ArraySlicer(axis=1),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_mask_nb(
    assets: tp.Array2d,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get position mask per column.

    An element is True if it lies within the simulation range of its column
    and the corresponding asset value is neither zero nor NaN."""
    out = np.full(assets.shape, False, dtype=np.bool_)

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=assets.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )

    for col in prange(assets.shape[1]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        if _sim_start >= _sim_end:
            continue

        for i in range(_sim_start, _sim_end):
            # NaN compares unequal to zero and would be flagged as an open position;
            # exclude it explicitly, as `position_mask_grouped_nb` does
            if not np.isnan(assets[i, col]) and assets[i, col] != 0:
                out[i, col] = True
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="assets", axis=1),
    arg_take_spec=dict(
        assets=ch.ArraySlicer(axis=1),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_coverage_nb(
    assets: tp.Array2d,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Get position coverage per column.

    For each column, returns the number of rows inside `[sim_start, sim_end)` where
    `assets` is neither zero nor NaN, divided by the length of that range.
    A column whose range is empty keeps the initial value of 0.0.

    NaN is treated as "no position", which matches the check performed by
    `position_coverage_grouped_nb`."""
    out = np.full(assets.shape[1], 0.0, dtype=float_)

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=assets.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for col in prange(assets.shape[1]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        if _sim_start >= _sim_end:
            # Empty range: nothing to count and no valid denominator
            continue

        hit_elements = 0
        for i in range(_sim_start, _sim_end):
            _assets = assets[i, col]
            # `NaN != 0` is True, so NaN has to be ruled out explicitly
            if not np.isnan(_assets) and _assets != 0:
                hit_elements += 1
        out[col] = hit_elements / (_sim_end - _sim_start)
    return out
arg_take_spec=dict( assets=base_ch.array_gl_slicer, group_lens=ch.ArraySlicer(axis=0), granular_groups=None, sim_start=base_ch.flex_1d_array_gl_slicer, sim_end=base_ch.flex_1d_array_gl_slicer, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def position_coverage_grouped_nb( assets: tp.Array2d, group_lens: tp.GroupLens, granular_groups: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """Get position coverage per group.""" out = np.full(len(group_lens), 0.0, dtype=float_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=assets.shape, sim_start=sim_start, sim_end=sim_end, ) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] n_elements = 0 hit_elements = 0 if granular_groups: for col in range(from_col, to_col): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue n_elements += _sim_end - _sim_start for i in range(_sim_start, _sim_end): if not np.isnan(assets[i, col]) and assets[i, col] != 0: hit_elements += 1 else: min_sim_start = assets.shape[0] max_sim_end = 0 for col in range(from_col, to_col): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue if _sim_start < min_sim_start: min_sim_start = _sim_start if _sim_end > max_sim_end: max_sim_end = _sim_end if min_sim_start >= max_sim_end: continue n_elements = max_sim_end - min_sim_start for i in range(min_sim_start, max_sim_end): for col in range(from_col, to_col): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue if not np.isnan(assets[i, col]) and assets[i, col] != 0: hit_elements += 1 break if n_elements == 0: out[group] = np.nan else: out[group] = hit_elements / n_elements return out # ############# Cash ############# # 
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        target_shape=base_ch.shape_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
        cash_sharing=None,
        cash_deposits_raw=RepFunc(portfolio_ch.get_cash_deposits_slicer),
        split_shared=None,
        weights=base_ch.flex_1d_array_gl_slicer,
        sim_start=base_ch.flex_1d_array_gl_slicer,
        sim_end=base_ch.flex_1d_array_gl_slicer,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_deposits_nb(
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    cash_deposits_raw: tp.FlexArray2dLike = 0.0,
    split_shared: bool = False,
    weights: tp.Optional[tp.FlexArray1dLike] = None,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get cash deposit series per column.

    Without cash sharing, `cash_deposits_raw` is read per column. With cash sharing,
    it is read per group and written to every column of that group, divided by the
    number of columns in the group if `split_shared` is True.

    A column's deposit is multiplied by its weight unless the weight is NaN or close
    to 1. Elements outside of a column's simulation range remain NaN."""
    deposits_2d = to_2d_array_nb(np.asarray(cash_deposits_raw))
    if weights is None:
        col_weights = np.full(target_shape[1], np.nan, dtype=float_)
    else:
        col_weights = to_1d_array_nb(np.asarray(weights).astype(float_))
    out = np.full(target_shape, np.nan, dtype=float_)

    # The simulation range is resolved identically for both layouts
    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=target_shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )

    if cash_sharing:
        group_stops = np.cumsum(group_lens)
        group_starts = group_stops - group_lens

        for group in prange(len(group_lens)):
            first_col = group_starts[group]
            stop_col = group_stops[group]
            for col in range(first_col, stop_col):
                row_from = range_starts[col]
                row_to = range_ends[col]
                if row_from < row_to:
                    weight = flex_select_1d_pc_nb(col_weights, col)
                    # The weight is constant per column, so decide once whether it applies
                    use_weight = not (np.isnan(weight) or is_close_nb(weight, 1.0))
                    for i in range(row_from, row_to):
                        # Deposits are defined per group when cash is shared
                        deposit = flex_select_nb(deposits_2d, i, group)
                        if split_shared:
                            if use_weight:
                                out[i, col] = weight * deposit / (stop_col - first_col)
                            else:
                                out[i, col] = deposit / (stop_col - first_col)
                        elif use_weight:
                            out[i, col] = weight * deposit
                        else:
                            out[i, col] = deposit
    else:
        for col in prange(target_shape[1]):
            row_from = range_starts[col]
            row_to = range_ends[col]
            if row_from < row_to:
                weight = flex_select_1d_pc_nb(col_weights, col)
                use_weight = not (np.isnan(weight) or is_close_nb(weight, 1.0))
                for i in range(row_from, row_to):
                    deposit = flex_select_nb(deposits_2d, i, col)
                    if use_weight:
                        out[i, col] = weight * deposit
                    else:
                        out[i, col] = deposit
    return out
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        target_shape=ch.ShapeSlicer(axis=1),
        cash_earnings_raw=base_ch.FlexArraySlicer(axis=1),
        weights=base_ch.FlexArraySlicer(),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_earnings_nb(
    target_shape: tp.Shape,
    cash_earnings_raw: tp.FlexArray2dLike = 0.0,
    weights: tp.Optional[tp.FlexArray1dLike] = None,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get cash earning series per column.

    Copies `cash_earnings_raw` into an array of `target_shape`, multiplying each
    column by its weight unless the weight is NaN or close to 1.
    Elements outside of a column's simulation range remain NaN."""
    earnings_2d = to_2d_array_nb(np.asarray(cash_earnings_raw))
    if weights is None:
        col_weights = np.full(target_shape[1], np.nan, dtype=float_)
    else:
        col_weights = to_1d_array_nb(np.asarray(weights).astype(float_))
    out = np.full(target_shape, np.nan, dtype=float_)

    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=target_shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for col in prange(target_shape[1]):
        row_from = range_starts[col]
        row_to = range_ends[col]
        if row_from < row_to:
            weight = flex_select_1d_pc_nb(col_weights, col)
            # The weight is constant per column, so decide once whether it applies
            use_weight = not (np.isnan(weight) or is_close_nb(weight, 1.0))
            for i in range(row_from, row_to):
                earnings = flex_select_nb(earnings_2d, i, col)
                if use_weight:
                    out[i, col] = weight * earnings
                else:
                    out[i, col] = earnings
    return out
@register_jitted(cache=True)
def get_free_cash_diff_nb(
    position_before: float,
    position_now: float,
    debt_now: float,
    price: float,
    fees: float,
) -> tp.Tuple[float, float]:
    """Get updated debt and free cash flow.

    The change in position is valued at `price` and reduced by `fees`. Buying while
    the previous position is short releases a proportional part of `debt_now`;
    selling into a short position adds the shorted value to the debt.

    Returns the new debt and the change in free cash."""
    size_diff = add_nb(position_now, -position_before)
    cash_diff = -size_diff * price - fees

    # Default outcome: debt is untouched and free cash moves with regular cash
    new_debt = debt_now
    free_cash_diff = cash_diff

    if is_close_nb(size_diff, 0):
        # Position did not change
        free_cash_diff = 0.0
    elif size_diff > 0:
        if position_before < 0:
            # Only the part of the purchase that covers the short reduces debt
            covered_size = abs(size_diff) if position_now < 0 else abs(position_before)
            avg_entry_price = debt_now / abs(position_before)
            released_debt = covered_size * avg_entry_price
            new_debt = add_nb(debt_now, -released_debt)
            free_cash_diff = add_nb(2 * released_debt, cash_diff)
    else:
        if position_now < 0:
            # Only the part of the sale that ends up short increases debt
            shorted_size = abs(size_diff) if position_before < 0 else abs(position_now)
            shorted_value = shorted_size * price
            new_debt = debt_now + shorted_value
            free_cash_diff = add_nb(cash_diff, -2 * shorted_value)
    return new_debt, free_cash_diff
size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( target_shape=ch.ShapeSlicer(axis=1), order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), col_map=base_ch.GroupMapSlicer(), free=None, cash_earnings=base_ch.FlexArraySlicer(axis=1), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def cash_flow_nb( target_shape: tp.Shape, order_records: tp.RecordArray, col_map: tp.GroupMap, free: bool = False, cash_earnings: tp.FlexArray2dLike = 0.0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Get (free) cash flow series per column.""" cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) out = np.full(target_shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=target_shape, sim_start=sim_start, sim_end=sim_end, ) for col in range(target_shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue for i in range(_sim_start, _sim_end): out[i, col] = flex_select_nb(cash_earnings_, i, col) col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue col_len = col_lens[col] if col_len == 0: continue last_id = -1 position_now = 0.0 debt_now = 0.0 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] i = order_record["idx"] side = order_record["side"] size = order_record["size"] price = order_record["price"] fees = order_record["fees"] if side == OrderSide.Sell: size *= -1 position_before = position_now 
@register_chunkable(
    size=ch.ArraySizer(arg_query="free_cash_flow", axis=1),
    arg_take_spec=dict(
        init_cash_raw=None,
        free_cash_flow=ch.ArraySlicer(axis=1),
        cash_deposits=base_ch.FlexArraySlicer(axis=1),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def align_init_cash_nb(
    init_cash_raw: int,
    free_cash_flow: tp.Array2d,
    cash_deposits: tp.FlexArray2dLike = 0.0,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Align initial cash to the maximum negative free cash flow per column or group.

    For each column, free cash flow and cash deposits are accumulated over the
    simulation range. If the lowest running total is negative, its absolute value
    becomes the initial cash; otherwise the initial cash is set to 1.0.
    A column whose simulation range is empty yields NaN.

    If `init_cash_raw` is `InitCashMode.AutoAlign`, every element is replaced by the
    maximum across columns. NaN elements are ignored when taking that maximum, so a
    single column with an empty range does not turn the whole result into NaN."""
    cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
    out = np.full(free_cash_flow.shape[1], np.nan, dtype=float_)

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=free_cash_flow.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for col in prange(free_cash_flow.shape[1]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        if _sim_start >= _sim_end:
            continue

        free_cash = 0.0
        min_req_cash = np.inf
        for i in range(_sim_start, _sim_end):
            free_cash = add_nb(free_cash, free_cash_flow[i, col])
            free_cash = add_nb(free_cash, flex_select_nb(cash_deposits_, i, col))
            # Track the deepest point the running free cash ever reaches
            if free_cash < min_req_cash:
                min_req_cash = free_cash
        if min_req_cash < 0:
            out[col] = np.abs(min_req_cash)
        else:
            out[col] = 1.0

    # Skip the reduction when there are no columns since it is undefined for an empty array
    if init_cash_raw == InitCashMode.AutoAlign and out.shape[0] > 0:
        out = np.full(out.shape, np.nanmax(out))
    return out
@register_jitted(cache=True)
def init_cash_grouped_nb(
    init_cash_raw: tp.FlexArray1d,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    weights: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Get initial cash per group.

    With cash sharing, `init_cash_raw` is read per group and multiplied by the mean
    weight of the group's columns, unless any of those weights is NaN or the mean
    is close to 1. Without cash sharing, `init_cash_raw` is read per column and the
    (optionally weighted) values are summed within each group."""
    out = np.empty(group_lens.shape, dtype=float_)
    if weights is None:
        col_weights = np.full(group_lens.sum(), np.nan, dtype=float_)
    else:
        col_weights = to_1d_array_nb(np.asarray(weights).astype(float_))

    group_stops = np.cumsum(group_lens)
    group_starts = group_stops - group_lens

    for group in range(len(group_lens)):
        first_col = group_starts[group]
        stop_col = group_stops[group]

        if cash_sharing:
            group_cash = flex_select_1d_pc_nb(init_cash_raw, group)

            # Sum the weights; a single NaN weight disables weighting for the group
            mean_weight = 0.0
            for col in range(first_col, stop_col):
                weight = flex_select_1d_pc_nb(col_weights, col)
                if np.isnan(mean_weight) or np.isnan(weight):
                    mean_weight = np.nan
                    break
                mean_weight += weight
            if not np.isnan(mean_weight):
                mean_weight /= group_lens[group]

            if np.isnan(mean_weight) or is_close_nb(mean_weight, 1.0):
                out[group] = group_cash
            else:
                out[group] = mean_weight * group_cash
        else:
            total_cash = 0.0
            for col in range(first_col, stop_col):
                col_cash = flex_select_1d_pc_nb(init_cash_raw, col)
                weight = flex_select_1d_pc_nb(col_weights, col)
                if np.isnan(weight) or is_close_nb(weight, 1.0):
                    total_cash += col_cash
                else:
                    total_cash += weight * col_cash
            out[group] = total_cash
    return out
@register_jitted(cache=True)
def init_position_value_nb(
    n_cols: int,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get initial position value per column.

    The value is the initial position multiplied by the initial price, or exactly
    zero when the initial position is zero (regardless of the price)."""
    positions = to_1d_array_nb(np.asarray(init_position))
    prices = to_1d_array_nb(np.asarray(init_price))

    out = np.empty(n_cols, dtype=float_)
    for col in range(n_cols):
        position = float(flex_select_1d_pc_nb(positions, col))
        price = float(flex_select_1d_pc_nb(prices, col))
        # Short-circuit a zero position so that a NaN price cannot leak into the result
        out[col] = 0.0 if position == 0 else position * price
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        asset_value=base_ch.array_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
        sim_start=base_ch.flex_1d_array_gl_slicer,
        sim_end=base_ch.flex_1d_array_gl_slicer,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_value_grouped_nb(
    asset_value: tp.Array2d,
    group_lens: tp.GroupLens,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get asset value series per group.

    Sums `asset_value` over the columns of each group, row by row. A row of a group
    remains NaN if it is not covered by the simulation range of any of its columns."""
    n_groups = len(group_lens)
    out = np.full((asset_value.shape[0], n_groups), np.nan, dtype=float_)
    group_stops = np.cumsum(group_lens)
    group_starts = group_stops - group_lens

    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=asset_value.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for group in prange(n_groups):
        for col in range(group_starts[group], group_stops[group]):
            row_from = range_starts[col]
            row_to = range_ends[col]
            if row_from < row_to:
                for i in range(row_from, row_to):
                    running = out[i, group]
                    # The first contribution to a row replaces the NaN placeholder
                    if np.isnan(running):
                        running = 0.0
                    out[i, group] = running + asset_value[i, col]
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="asset_value", axis=1),
    arg_take_spec=dict(
        asset_value=ch.ArraySlicer(axis=1),
        value=ch.ArraySlicer(axis=1),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def gross_exposure_nb(
    asset_value: tp.Array2d,
    value: tp.Array2d,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get gross exposure series per column.

    The exposure is the absolute ratio of `asset_value` to `value`. It is NaN where
    `value` is zero and outside of the column's simulation range."""
    out = np.full(asset_value.shape, np.nan, dtype=float_)

    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=asset_value.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for col in prange(asset_value.shape[1]):
        row_from = range_starts[col]
        row_to = range_ends[col]
        if row_from < row_to:
            for i in range(row_from, row_to):
                total_value = value[i, col]
                # A zero total value keeps the NaN that the output was initialized with
                if total_value != 0:
                    out[i, col] = abs(asset_value[i, col] / total_value)
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        asset_value=base_ch.array_gl_slicer,
        value=ch.ArraySlicer(axis=1),
        group_lens=ch.ArraySlicer(axis=0),
        sim_start=base_ch.flex_1d_array_gl_slicer,
        sim_end=base_ch.flex_1d_array_gl_slicer,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def allocations_nb(
    asset_value: tp.Array2d,
    value: tp.Array2d,
    group_lens: tp.GroupLens,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get allocations per column.

    Each column's `asset_value` is divided by the `value` of the group the column
    belongs to. The result is NaN where the group value is zero and outside of the
    column's simulation range."""
    out = np.full(asset_value.shape, np.nan, dtype=float_)
    group_stops = np.cumsum(group_lens)
    group_starts = group_stops - group_lens

    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=asset_value.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for group in prange(len(group_lens)):
        for col in range(group_starts[group], group_stops[group]):
            row_from = range_starts[col]
            row_to = range_ends[col]
            if row_from < row_to:
                for i in range(row_from, row_to):
                    group_value = value[i, group]
                    # A zero group value keeps the NaN that the output was initialized with
                    if group_value != 0:
                        out[i, col] = asset_value[i, col] / group_value
    return out
np.nan, cash_earnings: tp.FlexArray2dLike = 0.0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """Get total profit per column. A much faster version than the one based on `value_nb`.""" init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) assets = np.full(target_shape[1], 0.0, dtype=float_) cash = np.full(target_shape[1], 0.0, dtype=float_) total_profit = np.full(target_shape[1], np.nan, dtype=float_) col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=target_shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(target_shape[1]): _init_position = float(flex_select_1d_pc_nb(init_position_, col)) _init_price = float(flex_select_1d_pc_nb(init_price_, col)) if _init_position != 0: assets[col] = _init_position cash[col] = -_init_position * _init_price _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue for i in range(_sim_start, _sim_end): cash[col] += flex_select_nb(cash_earnings_, i, col) for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue col_len = col_lens[col] if col_len == 0: if assets[col] == 0 and cash[col] == 0: total_profit[col] = 0.0 continue last_id = -1 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] # Fill assets if order_record["side"] == OrderSide.Buy: order_size = order_record["size"] assets[col] = add_nb(assets[col], order_size) else: order_size = order_record["size"] assets[col] = add_nb(assets[col], 
@register_jitted(cache=True)
def total_profit_grouped_nb(total_profit: tp.Array1d, group_lens: tp.GroupLens) -> tp.Array1d:
    """Get total profit per group.

    Sums the per-column values in `total_profit` over each run of consecutive columns
    whose length is given by the corresponding element of `group_lens`."""
    n_groups = len(group_lens)
    # Column bounds of each group: [group_starts[g], group_stops[g])
    group_stops = np.cumsum(group_lens)
    group_starts = group_stops - group_lens
    grouped = np.empty(n_groups, dtype=float_)
    for g in range(n_groups):
        grouped[g] = np.sum(total_profit[group_starts[g] : group_stops[g]])
    return grouped
@register_jitted(cache=True)
def get_asset_pnl_nb(
    input_asset_value: float,
    output_asset_value: float,
    cash_flow: float,
) -> float:
    """Get asset PnL from the input and output asset value, and the cash flow.

    The result is the output asset value plus the cash flow, minus the input asset value."""
    # Add the cash flow first to keep the same floating-point evaluation order
    output_with_cash_flow = output_asset_value + cash_flow
    return output_with_cash_flow - input_asset_value
@register_jitted(cache=True)
def get_asset_return_nb(
    input_asset_value: float,
    output_asset_value: float,
    cash_flow: float,
    log_returns: bool = False,
) -> float:
    """Get asset return from the input and output asset value, and the cash flow.

    If the input asset value is close to zero, the return is measured from the negated
    output asset value to the cash flow. Otherwise, it is measured from the input asset
    value to the output asset value plus the cash flow.

    If both values of the pair are negative, the return is computed on their negations
    and its sign is flipped.

    If `log_returns` is True, `np.log1p` is applied to the simple return."""
    started_flat = is_close_nb(input_asset_value, 0)
    if started_flat:
        base_value = -output_asset_value
        final_value = cash_flow
    else:
        base_value = input_asset_value
        final_value = output_asset_value + cash_flow

    both_negative = base_value < 0 and final_value < 0
    if both_negative:
        simple_return = -returns_nb_.get_return_nb(-base_value, -final_value, log_returns=False)
    else:
        simple_return = returns_nb_.get_return_nb(base_value, final_value, log_returns=False)

    return np.log1p(simple_return) if log_returns else simple_return
@register_chunkable(
    size=ch.ArraySizer(arg_query="close", axis=1),
    arg_take_spec=dict(
        close=ch.ArraySlicer(axis=1),
        init_value=base_ch.FlexArraySlicer(),
        cash_deposits=base_ch.FlexArraySlicer(axis=1),
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def market_value_nb(
    close: tp.Array2d,
    init_value: tp.FlexArray1d,
    cash_deposits: tp.FlexArray2dLike = 0.0,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get market value per column.

    Within the simulation range of each column, the value starts at the column's
    `init_value`, is scaled on every row after the first one by the ratio of the current
    to the previous close, and is then increased by the row's cash deposit.

    Elements outside of the simulation range are NaN."""
    deposits = to_2d_array_nb(np.asarray(cash_deposits))
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=close.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    market_value = np.full(close.shape, np.nan, dtype=float_)

    for c in prange(close.shape[1]):
        first_row = sim_start_[c]
        stop_row = sim_end_[c]
        if first_row >= stop_row:
            continue

        # First row: no price change is applied, only the deposit
        running_value = flex_select_1d_pc_nb(init_value, c)
        running_value = running_value + flex_select_nb(deposits, first_row, c)
        market_value[first_row, c] = running_value

        # Remaining rows: apply the price change, then the deposit
        for row in range(first_row + 1, stop_row):
            running_value = running_value * (close[row, c] / close[row - 1, c])
            running_value = running_value + flex_select_nb(deposits, row, c)
            market_value[row, c] = running_value
    return market_value
@register_jitted(cache=True)
def total_market_return_nb(
    market_value: tp.Array2d,
    input_value: tp.FlexArray1d,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Get total market return per column or group.

    The return is the relative change from `input_value` to the market value in the
    last row of the simulation range.

    The result is NaN where the simulation range is empty or the input value is zero."""
    n_cols = market_value.shape[1]
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=market_value.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    total_return = np.full(n_cols, np.nan, dtype=float_)

    for c in range(n_cols):
        stop_row = sim_end_[c]
        if sim_start_[c] >= stop_row:
            continue
        base_value = flex_select_1d_pc_nb(input_value, c)
        if base_value == 0:
            # Relative change is undefined for a zero base
            continue
        final_value = market_value[stop_row - 1, c]
        total_return[c] = (final_value - base_value) / base_value
    return total_return
@register_jitted(cache=True)
def should_apply_size_granularity_nb(size: float, size_granularity: float) -> bool:
    """Whether to apply a size granularity to a size.

    Returns False if the granularity is NaN and True if it is a whole number.

    For a fractional granularity, returns True only if the size is close neither to
    the multiple of the granularity obtained by floor division nor to the next one."""
    if np.isnan(size_granularity):
        return False
    if size_granularity % 1 == 0:
        return True
    lower_multiple = size // size_granularity * size_granularity
    upper_multiple = lower_multiple + size_granularity
    # NOTE(review): the upper multiple is presumably checked to tolerate floor
    # division landing one step too low due to floating-point error - confirm
    return not (is_close_nb(size, lower_multiple) or is_close_nb(size, upper_multiple))
price_area_vio_mode: int = PriceAreaVioMode.Ignore, allow_partial: bool = True, percent: float = np.nan, price_area: PriceArea = NoPriceArea, is_closing_price: bool = False, ) -> tp.Tuple[OrderResult, AccountState]: """Open or increase a long position.""" _account_state = cast_account_state_nb(account_state) # Get cash limit cash_limit = _account_state.free_cash if not np.isnan(percent): cash_limit = cash_limit * percent if cash_limit <= 0: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoCash), _account_state cash_limit = cash_limit * leverage # Adjust for granularity if should_apply_size_granularity_nb(size, size_granularity): size = apply_size_granularity_nb(size, size_granularity) # Adjust for max size if not np.isnan(max_size) and size > max_size: if not allow_partial: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state size = max_size if np.isinf(size) and np.isinf(cash_limit): raise ValueError("Attempt to go in long direction infinitely") # Get price adjusted with slippage adj_price = price * (1 + slippage) adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode) # Get cash required to complete this order if np.isinf(size): req_cash = np.inf req_fees = np.inf else: order_value = size * adj_price req_fees = order_value * fees + fixed_fees req_cash = order_value + req_fees if is_close_or_less_nb(req_cash, cash_limit): # Sufficient amount of cash final_size = size fees_paid = req_fees else: # Insufficient amount of cash, size will be less than requested # For fees of 10% and 1$ per transaction, you can buy for 90$ (new_req_cash) # to spend 100$ (cash_limit) in total max_req_cash = add_nb(cash_limit, -fixed_fees) / (1 + fees) if max_req_cash <= 0: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state max_acq_size = max_req_cash / adj_price # Adjust for granularity if should_apply_size_granularity_nb(max_acq_size, 
size_granularity): final_size = apply_size_granularity_nb(max_acq_size, size_granularity) new_order_value = final_size * adj_price fees_paid = new_order_value * fees + fixed_fees req_cash = new_order_value + fees_paid else: final_size = max_acq_size fees_paid = cash_limit - max_req_cash req_cash = cash_limit # Check against size of zero if is_close_nb(final_size, 0): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state # Check against minimum size if not np.isnan(min_size) and is_less_nb(final_size, min_size): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state # Check against partial fill (np.inf doesn't count) if np.isfinite(size) and is_less_nb(final_size, size) and not allow_partial: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state # Create a filled order order_result = OrderResult( float(final_size), float(adj_price), float(fees_paid), OrderSide.Buy, OrderStatus.Filled, -1, ) # Update the current account state new_cash = add_nb(_account_state.cash, -req_cash) new_position = add_nb(_account_state.position, final_size) if leverage_mode == LeverageMode.Lazy: debt_diff = max(add_nb(req_cash, -_account_state.free_cash), 0.0) if debt_diff > 0: new_debt = _account_state.debt + debt_diff new_locked_cash = _account_state.locked_cash + _account_state.free_cash new_free_cash = 0.0 else: new_debt = _account_state.debt new_locked_cash = _account_state.locked_cash new_free_cash = add_nb(_account_state.free_cash, -req_cash) else: if leverage > 1: if np.isinf(leverage): raise ValueError("Leverage must be finite for LeverageMode.Eager") order_value = final_size * adj_price new_debt = _account_state.debt + order_value * (leverage - 1) / leverage new_locked_cash = _account_state.locked_cash + order_value / leverage new_free_cash = add_nb(_account_state.free_cash, -order_value / leverage - fees_paid) else: new_debt = _account_state.debt new_locked_cash 
@register_jitted(cache=True)
def approx_long_sell_value_nb(position: float, debt: float, val_price: float, size: float) -> float:
    """Approximate value of a long-sell operation.

    The absolute size is capped by the position. The value of the sold size at
    `val_price`, reduced by the proportional share of `debt`, is the approximate
    change in free cash; its negation is returned.

    Positive value means spending (for sorting reasons)."""
    if position == 0 or size == 0:
        return 0.0
    sold_size = min(abs(size), position)
    proceeds = sold_size * val_price
    # Debt is released in proportion to the fraction of the position being sold
    debt_released = sold_size / position * debt
    return -(proceeds - debt_released)
order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state size_limit = max_size # Check against size of zero if is_close_nb(size_limit, 0): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state # Check against minimum size if not np.isnan(min_size) and is_less_nb(size_limit, min_size): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state # Check against partial fill if np.isfinite(size) and is_less_nb(size_limit, size) and not allow_partial: # np.inf doesn't count return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state # Get price adjusted with slippage adj_price = price * (1 - slippage) adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode) # Get acquired cash acq_cash = size_limit * adj_price # Update fees fees_paid = acq_cash * fees + fixed_fees # Get final cash by subtracting costs final_acq_cash = add_nb(acq_cash, -fees_paid) if final_acq_cash < 0 and is_less_nb(_account_state.free_cash, -final_acq_cash): return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state # Create a filled order order_result = OrderResult( float(size_limit), float(adj_price), float(fees_paid), OrderSide.Sell, OrderStatus.Filled, -1, ) # Update the current account state new_cash = _account_state.cash + final_acq_cash new_position = add_nb(_account_state.position, -size_limit) new_pos_fraction = abs(new_position) / abs(_account_state.position) new_debt = new_pos_fraction * _account_state.debt new_locked_cash = new_pos_fraction * _account_state.locked_cash size_fraction = size_limit / _account_state.position released_debt = size_fraction * _account_state.debt new_free_cash = add_nb(_account_state.free_cash, final_acq_cash - released_debt) new_account_state = AccountState( cash=float(new_cash), position=float(new_position), debt=float(new_debt), 
@register_jitted(cache=True)
def approx_short_sell_value_nb(val_price: float, size: float) -> float:
    """Approximate value of a short-sell operation.

    The approximate change in free cash is the negated value of the absolute size
    at `val_price`; its negation is returned.

    Positive value means spending (for sorting reasons)."""
    if size == 0:
        return 0.0
    free_cash_change = -(abs(size) * val_price)
    return -free_cash_change
# Adjust for max size if not np.isnan(max_size) and size_limit > max_size: if not allow_partial: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state size_limit = max_size if np.isinf(size_limit): raise ValueError("Attempt to go in short direction infinitely") # Check against size of zero if is_close_nb(size_limit, 0): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state # Check against minimum size if not np.isnan(min_size) and is_less_nb(size_limit, min_size): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state # Check against partial fill if np.isfinite(size) and is_less_nb(size_limit, size) and not allow_partial: # np.inf doesn't count return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state # Get acquired cash order_value = size_limit * adj_price # Update fees fees_paid = order_value * fees + fixed_fees # Get final cash by subtracting costs final_acq_cash = add_nb(order_value, -fees_paid) if final_acq_cash < 0: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state # Create a filled order order_result = OrderResult( float(size_limit), float(adj_price), float(fees_paid), OrderSide.Sell, OrderStatus.Filled, -1, ) # Update the current account state new_cash = _account_state.cash + final_acq_cash new_position = _account_state.position - size_limit new_debt = _account_state.debt + order_value if np.isinf(leverage): if np.isinf(_account_state.free_cash): raise ValueError("Leverage must be finite when _account_state.free_cash is infinite") if is_close_or_less_nb(_account_state.free_cash, fees_paid): return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state leverage_ = order_value / (_account_state.free_cash - fees_paid) else: leverage_ = float(leverage) new_locked_cash = _account_state.locked_cash + order_value / leverage_ 
@register_jitted(cache=True)
def approx_short_buy_value_nb(position: float, debt: float, locked_cash: float, val_price: float, size: float) -> float:
    """Approximate value of a short-buy operation.

    The absolute size is capped by the absolute position. The proportional shares of
    `locked_cash` and `debt`, reduced by the value of the bought size at `val_price`,
    are the approximate change in free cash; its negation is returned.

    Positive value means spending (for sorting reasons)."""
    if position == 0 or size == 0:
        return 0.0
    abs_position = abs(position)
    covered_size = min(abs(size), abs_position)
    cost = covered_size * val_price
    # Debt and locked cash are released in proportion to the fraction being covered
    covered_fraction = covered_size / abs_position
    debt_released = covered_fraction * debt
    cash_released = covered_fraction * locked_cash
    free_cash_change = cash_released + debt_released - cost
    return -free_cash_change
should_apply_size_granularity_nb(size_limit, size_granularity): size_limit = apply_size_granularity_nb(size_limit, size_granularity) # Adjust for max size if not np.isnan(max_size) and size_limit > max_size: if not allow_partial: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state size_limit = max_size # Get price adjusted with slippage adj_price = price * (1 + slippage) adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode) # Get cash required to complete this order if np.isinf(size_limit): req_cash = np.inf req_fees = np.inf else: order_value = size_limit * adj_price req_fees = order_value * fees + fixed_fees req_cash = order_value + req_fees if is_close_or_less_nb(req_cash, cash_limit): # Sufficient amount of cash final_size = size_limit fees_paid = req_fees else: # Insufficient amount of cash, size will be less than requested # For fees of 10% and 1$ per transaction, you can buy for 90$ (new_req_cash) # to spend 100$ (cash_limit) in total max_req_cash = add_nb(cash_limit, -fixed_fees) / (1 + fees) if max_req_cash <= 0: return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state max_acq_size = max_req_cash / adj_price # Adjust for granularity if should_apply_size_granularity_nb(max_acq_size, size_granularity): final_size = apply_size_granularity_nb(max_acq_size, size_granularity) new_order_value = final_size * adj_price fees_paid = new_order_value * fees + fixed_fees req_cash = new_order_value + fees_paid else: final_size = max_acq_size fees_paid = cash_limit - max_req_cash req_cash = cash_limit # Check size of zero if is_close_nb(final_size, 0): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state # Check against minimum size if not np.isnan(min_size) and is_less_nb(final_size, min_size): return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state # Check against 
@register_jitted(cache=True)
def approx_buy_value_nb(
    position: float,
    debt: float,
    locked_cash: float,
    val_price: float,
    size: float,
    direction: int,
) -> float:
    """Approximate value of a buy operation.

    Dispatches to `approx_short_buy_value_nb` for a non-positive position with
    `Direction.ShortOnly`, and to `approx_long_buy_value_nb` for any other non-negative
    position. Otherwise, the value of the short-buy part is computed and, if the size
    exceeds the absolute position, the value of a long buy of the excess is added.

    Positive value means spending (for sorting reasons)."""
    if direction == Direction.ShortOnly and position <= 0:
        return approx_short_buy_value_nb(position, debt, locked_cash, val_price, size)
    if position >= 0:
        return approx_long_buy_value_nb(val_price, size)

    # Part that reduces the existing short position
    cover_value = approx_short_buy_value_nb(position, debt, locked_cash, val_price, size)
    excess_size = add_nb(size, -abs(position))
    if excess_size <= 0:
        return cover_value
    # Part that opens a long position with the remaining size
    return cover_value + approx_long_buy_value_nb(val_price, excess_size)
abs(_account_state.position)) if not np.isnan(min_size): min_size1 = min(min_size, abs(_account_state.position)) else: min_size1 = np.nan if not np.isnan(max_size): max_size1 = min(max_size, abs(_account_state.position)) else: max_size1 = np.nan new_order_result1, new_account_state1 = short_buy_nb( account_state=_account_state, size=short_size, price=price, fees=fees, fixed_fees=fixed_fees, slippage=slippage, min_size=min_size1, max_size=max_size1, size_granularity=size_granularity, price_area_vio_mode=price_area_vio_mode, allow_partial=allow_partial, percent=np.nan, price_area=price_area, is_closing_price=is_closing_price, ) if new_order_result1.status != OrderStatus.Filled: return new_order_result1, _account_state if new_account_state1.position != 0: return new_order_result1, new_account_state1 new_size = add_nb(size, -abs(_account_state.position)) if new_size <= 0: return new_order_result1, new_account_state1 if not np.isnan(min_size): min_size2 = max(min_size - abs(_account_state.position), 0.0) else: min_size2 = np.nan if not np.isnan(max_size): max_size2 = max(max_size - abs(_account_state.position), 0.0) else: max_size2 = np.nan new_order_result2, new_account_state2 = long_buy_nb( account_state=new_account_state1, size=new_size, price=price, fees=fees, fixed_fees=0.0, slippage=slippage, min_size=min_size2, max_size=max_size2, size_granularity=size_granularity, leverage=leverage, leverage_mode=leverage_mode, price_area_vio_mode=price_area_vio_mode, allow_partial=allow_partial, percent=percent, price_area=price_area, is_closing_price=is_closing_price, ) if new_order_result2.status != OrderStatus.Filled: if allow_partial or np.isinf(new_size): if new_order_result2.status_info == OrderStatusInfo.SizeZero: return new_order_result1, new_account_state1 if new_order_result2.status_info == OrderStatusInfo.NoCash: return new_order_result1, new_account_state1 return new_order_result2, _account_state new_order_result = OrderResult( new_order_result1.size + 
@register_jitted(cache=True)
def approx_sell_value_nb(
    position: float,
    debt: float,
    val_price: float,
    size: float,
    direction: int,
) -> float:
    """Estimate the value of selling `size` units.

    A positive result means spending; this sign convention is kept for sorting purposes.

    Args:
        position (float): Current position (negative when short).
        debt (float): Passed through to `approx_long_sell_value_nb`.
        val_price (float): Valuation price used for the estimate.
        size (float): Number of units to sell.
        direction (int): Order direction, compared against `Direction.LongOnly`.

    Returns:
        float: Estimated value of the operation.
    """
    long_only = direction == Direction.LongOnly

    # Long-only orders on a flat or long position are valued purely as a long sell
    if long_only and position >= 0:
        return approx_long_sell_value_nb(position, debt, val_price, size)

    # Flat or short position: the whole size is a short sell
    if position <= 0:
        return approx_short_sell_value_nb(val_price, size)

    # Remaining case: value the long-sell leg first, then whatever size
    # exceeds the absolute position is valued as a short sell on top of it
    close_value = approx_long_sell_value_nb(position, debt, val_price, size)
    excess_size = add_nb(size, -abs(position))
    if excess_size <= 0:
        return close_value
    open_value = approx_short_sell_value_nb(val_price, excess_size)
    return close_value + open_value
fees=fees, fixed_fees=fixed_fees, slippage=slippage, min_size=min_size, max_size=max_size, size_granularity=size_granularity, leverage=leverage, price_area_vio_mode=price_area_vio_mode, allow_partial=allow_partial, percent=percent, price_area=price_area, is_closing_price=is_closing_price, ) long_size = min(size, _account_state.position) if not np.isnan(min_size): min_size1 = min(min_size, _account_state.position) else: min_size1 = np.nan if not np.isnan(max_size): max_size1 = min(max_size, _account_state.position) else: max_size1 = np.nan new_order_result1, new_account_state1 = long_sell_nb( account_state=_account_state, size=long_size, price=price, fees=fees, fixed_fees=fixed_fees, slippage=slippage, min_size=min_size1, max_size=max_size1, size_granularity=size_granularity, price_area_vio_mode=price_area_vio_mode, allow_partial=allow_partial, percent=np.nan, price_area=price_area, is_closing_price=is_closing_price, ) if new_order_result1.status != OrderStatus.Filled: return new_order_result1, _account_state if new_account_state1.position != 0: return new_order_result1, new_account_state1 new_size = add_nb(size, -abs(_account_state.position)) if new_size <= 0: return new_order_result1, new_account_state1 if not np.isnan(min_size): min_size2 = max(min_size - _account_state.position, 0.0) else: min_size2 = np.nan if not np.isnan(max_size): max_size2 = max(max_size - _account_state.position, 0.0) else: max_size2 = np.nan new_order_result2, new_account_state2 = short_sell_nb( account_state=new_account_state1, size=new_size, price=price, fees=fees, fixed_fees=0.0, slippage=slippage, min_size=min_size2, max_size=max_size2, size_granularity=size_granularity, leverage=leverage, price_area_vio_mode=price_area_vio_mode, allow_partial=allow_partial, percent=percent, price_area=price_area, is_closing_price=is_closing_price, ) if new_order_result2.status != OrderStatus.Filled: if allow_partial or np.isinf(new_size): if new_order_result2.status_info == OrderStatusInfo.SizeZero: 
@register_jitted(cache=True)
def update_value_nb(
    cash_before: float,
    cash_now: float,
    position_before: float,
    position_now: float,
    val_price_before: float,
    val_price_now: float,
    value_before: float,
) -> float:
    """Return the value after a change in cash, position, and valuation price.

    The new value is the previous value plus the cash flow plus the change in asset value.
    A zero position contributes an asset value of exactly 0.0 without touching its
    valuation price.
    """
    asset_value_before = position_before * val_price_before if position_before != 0 else 0.0
    asset_value_now = position_now * val_price_now if position_now != 0 else 0.0
    cash_flow = cash_now - cash_before
    return value_before + cash_flow + (asset_value_now - asset_value_before)


@register_jitted(cache=True)
def get_diraware_size_nb(size: float, direction: int) -> float:
    """Return `size` with its sign flipped when `direction` is `Direction.ShortOnly`."""
    return size * -1 if direction == Direction.ShortOnly else size


@register_jitted(cache=True)
def resolve_size_nb(
    size: float,
    size_type: int,
    position: float,
    val_price: float,
    value: float,
    target_size_type: int = SizeType.Amount,
    as_requirement: bool = False,
) -> tp.Tuple[float, float]:
    """Convert a size from `size_type` to `target_size_type`.

    The size is pushed through a fixed chain of conversions (percent-of-value to value
    to amount, target amount to amount difference, percent of resources to infinite
    amount) and is returned as soon as the target type is reached.

    Args:
        size (float): Size to convert.
        size_type (int): Current type of `size`, a member of `SizeType`.
        position (float): Current position, subtracted from a target amount.
        val_price (float): Valuation price used to turn a value into an amount.
        value (float): Value that percentages of value are multiplied by.
        target_size_type (int): Type to convert into.
        as_requirement (bool): Treat `size` as a limit rather than as an order size:
            the position is not subtracted from a target amount, a percentage resolves
            to a NaN amount, and a size that runs through the whole chain is returned
            as an absolute value.

    Returns:
        tuple: The resolved size and the percentage of resources. The percentage is NaN
            unless the size passed through `SizeType.Percent` or `SizeType.Percent100`.

    Raises:
        ValueError: If the chain ends on a type other than `target_size_type`.
    """
    percent = np.nan
    if size_type == target_size_type:
        # Already expressed in the requested type
        return float(size), percent

    amount = size
    kind = size_type

    # Percentages scaled by 100 become fractions. The type here is still the original
    # one, which is known to differ from the target, so no early exit is possible.
    if kind == SizeType.ValuePercent100:
        amount /= 100
        kind = SizeType.ValuePercent
    elif kind == SizeType.TargetPercent100:
        amount /= 100
        kind = SizeType.TargetPercent

    # Fraction of value -> value
    if kind == SizeType.ValuePercent or kind == SizeType.TargetPercent:
        if kind == target_size_type:
            return float(amount), percent
        amount *= value
        kind = SizeType.Value if kind == SizeType.ValuePercent else SizeType.TargetValue

    # Value -> amount
    if kind == SizeType.Value or kind == SizeType.TargetValue:
        if kind == target_size_type:
            return float(amount), percent
        amount /= val_price
        kind = SizeType.Amount if kind == SizeType.Value else SizeType.TargetAmount

    # Target amount -> difference to the current position
    if kind == SizeType.TargetAmount:
        if kind == target_size_type:
            return float(amount), percent
        if not as_requirement:
            amount -= position
        kind = SizeType.Amount

    # Percent of resources: remember the percentage and make the amount unbounded.
    # As above, `Percent100` can only be the original type here.
    if kind == SizeType.Percent100:
        amount /= 100
        kind = SizeType.Percent
    if kind == SizeType.Percent:
        if kind == target_size_type:
            return float(amount), percent
        percent = abs(amount)
        amount = np.nan if as_requirement else np.sign(amount) * np.inf
        kind = SizeType.Amount

    if kind != target_size_type:
        raise ValueError("Cannot convert size to target size type")
    return float(abs(amount) if as_requirement else amount), percent
@register_jitted(cache=True)
def execute_order_nb(
    exec_state: ExecState,
    order: Order,
    price_area: PriceArea = NoPriceArea,
    update_value: bool = False,
) -> tp.Tuple[OrderResult, ExecState]:
    """Execute an order given the current state.

    Args:
        exec_state (ExecState): See `vectorbtpro.portfolio.enums.ExecState`.
        order (Order): See `vectorbtpro.portfolio.enums.Order`.
        price_area (PriceArea): See `vectorbtpro.portfolio.enums.PriceArea`.
        update_value (bool): Whether to update the valuation price and value
            after the order has been filled.

    Returns:
        tuple: The order result and the execution state after the order.
            Whenever the order is ignored or rejected by this function itself,
            the passed `exec_state` is returned unchanged.

    Raises:
        ValueError: If the price area, account state, or order contains a value
            that is not expected.
        RejectedOrderError: Through `raise_rejected_order_nb`, if the buy/sell
            operation rejected the order and `order.raise_reject` is set.

    Error is thrown if an input has value that is not expected.
    Order is ignored if its execution has no effect on the current balance.
    Order is rejected if an input goes over a limit or against a restriction.

    !!! note
        Rejections that are returned directly from this function
        (`OrderStatusInfo.ValueZeroNeg` and `OrderStatusInfo.RandomEvent`) return
        before `order.raise_reject` is consulted and thus never raise here.
    """
    # numerical stability: snap every state field that is close to zero to exactly zero
    cash = float(exec_state.cash)
    if is_close_nb(cash, 0):
        cash = 0.0
    position = float(exec_state.position)
    if is_close_nb(position, 0):
        position = 0.0
    debt = float(exec_state.debt)
    if is_close_nb(debt, 0):
        debt = 0.0
    locked_cash = float(exec_state.locked_cash)
    if is_close_nb(locked_cash, 0):
        locked_cash = 0.0
    free_cash = float(exec_state.free_cash)
    if is_close_nb(free_cash, 0):
        free_cash = 0.0
    val_price = float(exec_state.val_price)
    if is_close_nb(val_price, 0):
        val_price = 0.0
    value = float(exec_state.value)
    if is_close_nb(value, 0):
        value = 0.0

    # Pre-fill account state (built from the snapped values, not from `exec_state` directly)
    account_state = AccountState(
        cash=cash,
        position=position,
        debt=debt,
        locked_cash=locked_cash,
        free_cash=free_cash,
    )

    # Check price area (NaN passes both comparisons and is therefore allowed)
    if np.isinf(price_area.open) or price_area.open < 0:
        raise ValueError("price_area.open must be either NaN, or finite and 0 or greater")
    if np.isinf(price_area.high) or price_area.high < 0:
        raise ValueError("price_area.high must be either NaN, or finite and 0 or greater")
    if np.isinf(price_area.low) or price_area.low < 0:
        raise ValueError("price_area.low must be either NaN, or finite and 0 or greater")
    if np.isinf(price_area.close) or price_area.close < 0:
        raise ValueError("price_area.close must be either NaN, or finite and 0 or greater")

    # Resolve price: +inf stands for the closing price, -inf for the opening price
    order_price = order.price
    is_closing_price = False
    if np.isinf(order_price):
        if order_price > 0:
            order_price = price_area.close
            is_closing_price = True
        else:
            order_price = price_area.open
    elif order_price == PriceType.NextOpen:
        raise ValueError("Next open must be handled higher in the stack")
    elif order_price == PriceType.NextClose:
        raise ValueError("Next close must be handled higher in the stack")

    # Ignore order if size or price is nan
    # (this happens before the validation below, so such orders never raise)
    if np.isnan(order.size):
        return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeNaN), exec_state
    if np.isnan(order_price):
        return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.PriceNaN), exec_state

    # Check account state
    if np.isnan(cash):
        raise ValueError("exec_state.cash cannot be NaN")
    if not np.isfinite(position):
        raise ValueError("exec_state.position must be finite")
    if not np.isfinite(debt) or debt < 0:
        raise ValueError("exec_state.debt must be finite and 0 or greater")
    if not np.isfinite(locked_cash) or locked_cash < 0:
        raise ValueError("exec_state.locked_cash must be finite and 0 or greater")
    if np.isnan(free_cash):
        raise ValueError("exec_state.free_cash cannot be NaN")

    # Check order
    if not np.isfinite(order_price) or order_price < 0:
        raise ValueError("order.price must be finite and 0 or greater")
    if order.size_type < 0 or order.size_type >= len(SizeType):
        raise ValueError("order.size_type is invalid")
    if order.direction < 0 or order.direction >= len(Direction):
        raise ValueError("order.direction is invalid")
    if not np.isfinite(order.fees):
        raise ValueError("order.fees must be finite")
    if not np.isfinite(order.fixed_fees):
        raise ValueError("order.fixed_fees must be finite")
    if not np.isfinite(order.slippage) or order.slippage < 0:
        raise ValueError("order.slippage must be finite and 0 or greater")
    if np.isinf(order.min_size) or order.min_size < 0:
        raise ValueError("order.min_size must be either NaN, 0, or greater")
    if order.max_size <= 0:
        raise ValueError("order.max_size must be either NaN or greater than 0")
    if np.isinf(order.size_granularity) or order.size_granularity <= 0:
        raise ValueError("order.size_granularity must be either NaN, or finite and greater than 0")
    if np.isnan(order.leverage) or order.leverage <= 0:
        raise ValueError("order.leverage must be greater than 0")
    if order.leverage_mode < 0 or order.leverage_mode >= len(LeverageMode):
        raise ValueError("order.leverage_mode is invalid")
    if not np.isfinite(order.reject_prob) or order.reject_prob < 0 or order.reject_prob > 1:
        raise ValueError("order.reject_prob must be between 0 and 1")

    # Positive/negative size in short direction should be treated as negative/positive
    order_size = get_diraware_size_nb(order.size, order.direction)
    min_order_size = order.min_size
    max_order_size = order.max_size
    order_size_type = order.size_type

    # Size types that are converted through the valuation price require it to be usable
    if (
        order_size_type == SizeType.ValuePercent100
        or order_size_type == SizeType.ValuePercent
        or order_size_type == SizeType.TargetPercent100
        or order_size_type == SizeType.TargetPercent
        or order_size_type == SizeType.Value
        or order_size_type == SizeType.TargetValue
    ):
        if np.isinf(val_price) or val_price <= 0:
            raise ValueError("val_price_now must be finite and greater than 0")
        if np.isnan(val_price):
            return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.ValPriceNaN), exec_state

    # Size types that are expressed as a percentage of the value require the value to be positive
    if (
        order_size_type == SizeType.ValuePercent100
        or order_size_type == SizeType.ValuePercent
        or order_size_type == SizeType.TargetPercent100
        or order_size_type == SizeType.TargetPercent
    ):
        if np.isnan(value):
            return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.ValueNaN), exec_state
        if value <= 0:
            # Returned directly: `order.raise_reject` is not consulted on this path
            return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.ValueZeroNeg), exec_state

    # Convert the size (and below its limits) using the order's size type
    order_size, percent = resolve_size_nb(
        size=order_size,
        size_type=order_size_type,
        position=position,
        val_price=val_price,
        value=value,
    )
    if not np.isnan(min_order_size):
        min_order_size, min_percent = resolve_size_nb(
            size=min_order_size,
            size_type=order_size_type,
            position=position,
            val_price=val_price,
            value=value,
            as_requirement=True,
        )
        # Both percentages are only set for percent-based size types
        if not np.isnan(percent) and not np.isnan(min_percent) and is_less_nb(percent, min_percent):
            return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), exec_state
    if not np.isnan(max_order_size):
        max_order_size, max_percent = resolve_size_nb(
            size=max_order_size,
            size_type=order_size_type,
            position=position,
            val_price=val_price,
            value=value,
            as_requirement=True,
        )
        # Cap the requested percentage at the maximum one
        if not np.isnan(percent) and not np.isnan(max_percent) and is_less_nb(max_percent, percent):
            percent = max_percent

    # Non-negative resolved size is a buy, negative is a sell of the absolute size
    if order_size >= 0:
        order_result, new_account_state = buy_nb(
            account_state=account_state,
            size=order_size,
            price=order_price,
            direction=order.direction,
            fees=order.fees,
            fixed_fees=order.fixed_fees,
            slippage=order.slippage,
            min_size=min_order_size,
            max_size=max_order_size,
            size_granularity=order.size_granularity,
            leverage=order.leverage,
            leverage_mode=order.leverage_mode,
            price_area_vio_mode=order.price_area_vio_mode,
            allow_partial=order.allow_partial,
            percent=percent,
            price_area=price_area,
            is_closing_price=is_closing_price,
        )
    else:
        order_result, new_account_state = sell_nb(
            account_state=account_state,
            size=-order_size,
            price=order_price,
            direction=order.direction,
            fees=order.fees,
            fixed_fees=order.fixed_fees,
            slippage=order.slippage,
            min_size=min_order_size,
            max_size=max_order_size,
            size_granularity=order.size_granularity,
            leverage=order.leverage,
            price_area_vio_mode=order.price_area_vio_mode,
            allow_partial=order.allow_partial,
            percent=percent,
            price_area=price_area,
            is_closing_price=is_closing_price,
        )

    # Random rejection is drawn after the buy/sell call and discards its result.
    # It returns directly, so `order.raise_reject` is not consulted on this path.
    if order.reject_prob > 0:
        if np.random.uniform(0, 1) < order.reject_prob:
            return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.RandomEvent), exec_state

    # Escalate a rejection coming from the buy/sell call if requested
    if order_result.status == OrderStatus.Rejected and order.raise_reject:
        raise_rejected_order_nb(order_result)

    # The fill price becomes the new valuation price only if the update was requested
    is_filled = order_result.status == OrderStatus.Filled
    if is_filled and update_value:
        new_val_price = order_result.price
        new_value = update_value_nb(
            cash,
            new_account_state.cash,
            position,
            new_account_state.position,
            val_price,
            order_result.price,
            value,
        )
    else:
        new_val_price = val_price
        new_value = value

    # Built from the account state returned by the buy/sell call,
    # regardless of whether the order has been filled
    new_exec_state = ExecState(
        cash=new_account_state.cash,
        position=new_account_state.position,
        debt=new_account_state.debt,
        locked_cash=new_account_state.locked_cash,
        free_cash=new_account_state.free_cash,
        val_price=new_val_price,
        value=new_value,
    )

    return order_result, new_exec_state
order_id: int, ) -> None: """Fill a log record.""" records["id"][r, col] = r records["group"][r, col] = group records["col"][r, col] = col records["idx"][r, col] = i records["price_area_open"][r, col] = price_area.open records["price_area_high"][r, col] = price_area.high records["price_area_low"][r, col] = price_area.low records["price_area_close"][r, col] = price_area.close records["st0_cash"][r, col] = exec_state.cash records["st0_position"][r, col] = exec_state.position records["st0_debt"][r, col] = exec_state.debt records["st0_locked_cash"][r, col] = exec_state.locked_cash records["st0_free_cash"][r, col] = exec_state.free_cash records["st0_val_price"][r, col] = exec_state.val_price records["st0_value"][r, col] = exec_state.value records["req_size"][r, col] = order.size records["req_price"][r, col] = order.price records["req_size_type"][r, col] = order.size_type records["req_direction"][r, col] = order.direction records["req_fees"][r, col] = order.fees records["req_fixed_fees"][r, col] = order.fixed_fees records["req_slippage"][r, col] = order.slippage records["req_min_size"][r, col] = order.min_size records["req_max_size"][r, col] = order.max_size records["req_size_granularity"][r, col] = order.size_granularity records["req_leverage"][r, col] = order.leverage records["req_leverage_mode"][r, col] = order.leverage_mode records["req_reject_prob"][r, col] = order.reject_prob records["req_price_area_vio_mode"][r, col] = order.price_area_vio_mode records["req_allow_partial"][r, col] = order.allow_partial records["req_raise_reject"][r, col] = order.raise_reject records["req_log"][r, col] = order.log records["res_size"][r, col] = order_result.size records["res_price"][r, col] = order_result.price records["res_fees"][r, col] = order_result.fees records["res_side"][r, col] = order_result.side records["res_status"][r, col] = order_result.status records["res_status_info"][r, col] = order_result.status_info records["st1_cash"][r, col] = new_exec_state.cash 
@register_jitted(cache=True)
def fill_order_record_nb(records: tp.RecordArray2d, r: int, col: int, i: int, order_result: OrderResult) -> None:
    """Write an order result into row `r` of column `col` of the order records.

    Args:
        records (array): Two-dimensional record array indexed by `[r, col]`.
        r (int): Row to write to; also stored as the record id.
        col (int): Column to write to.
        i (int): Stored in the field `idx`.
        order_result (OrderResult): Source of size, price, fees, and side.
    """
    # What was filled
    records["size"][r, col] = order_result.size
    records["price"][r, col] = order_result.price
    records["fees"][r, col] = order_result.fees
    records["side"][r, col] = order_result.side

    # Where the record belongs
    records["id"][r, col] = r
    records["col"][r, col] = col
    records["idx"][r, col] = i


@register_jitted(cache=True)
def raise_rejected_order_nb(order_result: OrderResult) -> None:
    """Raise `vectorbtpro.portfolio.enums.RejectedOrderError` for an order result.

    The message is picked by `order_result.status_info`; a status info without
    a dedicated message raises the error without a message. This function never
    returns normally.
    """
    info = order_result.status_info
    if info == OrderStatusInfo.SizeNaN:
        raise RejectedOrderError("Size is NaN")
    elif info == OrderStatusInfo.PriceNaN:
        raise RejectedOrderError("Price is NaN")
    elif info == OrderStatusInfo.ValPriceNaN:
        raise RejectedOrderError("Asset valuation price is NaN")
    elif info == OrderStatusInfo.ValueNaN:
        raise RejectedOrderError("Asset/group value is NaN")
    elif info == OrderStatusInfo.ValueZeroNeg:
        raise RejectedOrderError("Asset/group value is zero or negative")
    elif info == OrderStatusInfo.SizeZero:
        raise RejectedOrderError("Size is zero")
    elif info == OrderStatusInfo.NoCash:
        raise RejectedOrderError("Not enough cash")
    elif info == OrderStatusInfo.NoOpenPosition:
        raise RejectedOrderError("No open position to reduce/close")
    elif info == OrderStatusInfo.MaxSizeExceeded:
        raise RejectedOrderError("Size is greater than maximum allowed")
    elif info == OrderStatusInfo.RandomEvent:
        raise RejectedOrderError("Random event happened")
    elif info == OrderStatusInfo.CantCoverFees:
        raise RejectedOrderError("Not enough cash to cover fees")
    elif info == OrderStatusInfo.MinSizeNotReached:
        raise RejectedOrderError("Final size is less than minimum allowed")
    elif info == OrderStatusInfo.PartialFill:
        raise RejectedOrderError("Final size is less than requested")
    raise RejectedOrderError


@register_jitted(cache=True)
def process_order_nb(
    group: int,
    col: int,
    i: int,
    exec_state: ExecState,
    order: Order,
    price_area: PriceArea = NoPriceArea,
    update_value: bool = False,
    order_records: tp.Optional[tp.RecordArray2d] = None,
    order_counts: tp.Optional[tp.Array1d] = None,
    log_records: tp.Optional[tp.RecordArray2d] = None,
    log_counts: tp.Optional[tp.Array1d] = None,
) -> tp.Tuple[OrderResult, ExecState]:
    """Execute an order, record it, and return the result together with the new state.

    An order record is written only if the order has been filled, both `order_records`
    and `order_counts` are provided, and `order_records` has at least one row.
    A log record is written only if `order.log` is set, both `log_records` and
    `log_counts` are provided, and `log_records` has at least one row. The order id
    stored in the log is `order_counts[col] - 1` for a filled order when `order_counts`
    is provided, and -1 otherwise.

    Raises:
        IndexError: If the column has no free row left in the order or log records.
    """
    order_result, new_exec_state = execute_order_nb(
        exec_state=exec_state,
        order=order,
        price_area=price_area,
        update_value=update_value,
    )
    is_filled = order_result.status == OrderStatus.Filled

    # Order record (the `None` checks stay on their own level)
    if order_records is not None and order_counts is not None:
        if is_filled and order_records.shape[0] > 0:
            order_row = order_counts[col]
            if order_row >= order_records.shape[0]:
                raise IndexError("order_records index out of range. Set a higher max_order_records.")
            fill_order_record_nb(order_records, order_row, col, i, order_result)
            order_counts[col] = order_row + 1

    # Log record
    if log_records is not None and log_counts is not None:
        if order.log and log_records.shape[0] > 0:
            log_row = log_counts[col]
            if log_row >= log_records.shape[0]:
                raise IndexError("log_records index out of range. Set a higher max_log_records.")
            # The counter was already advanced above, hence the -1
            linked_order_id = order_counts[col] - 1 if order_counts is not None and is_filled else -1
            fill_log_record_nb(
                log_records,
                log_row,
                group,
                col,
                i,
                price_area,
                exec_state,
                order,
                order_result,
                new_exec_state,
                linked_order_id,
            )
            log_counts[col] = log_row + 1

    return order_result, new_exec_state


@register_jitted(cache=True)
def order_nb(
    size: float = np.inf,
    price: float = np.inf,
    size_type: int = SizeType.Amount,
    direction: int = Direction.Both,
    fees: float = 0.0,
    fixed_fees: float = 0.0,
    slippage: float = 0.0,
    min_size: float = np.nan,
    max_size: float = np.nan,
    size_granularity: float = np.nan,
    leverage: float = 1.0,
    leverage_mode: int = LeverageMode.Lazy,
    reject_prob: float = 0.0,
    price_area_vio_mode: int = PriceAreaVioMode.Ignore,
    allow_partial: bool = True,
    raise_reject: bool = False,
    log: bool = False,
) -> Order:
    """Build an `Order` with every field cast to its expected primitive type.

    Numeric fields are cast to `float`, enumerated fields to `int`, and flags to `bool`.

    See `vectorbtpro.portfolio.enums.Order` for details on arguments."""
    new_order = Order(
        # What and how much
        size=float(size),
        price=float(price),
        size_type=int(size_type),
        direction=int(direction),
        # Costs
        fees=float(fees),
        fixed_fees=float(fixed_fees),
        slippage=float(slippage),
        # Size constraints
        min_size=float(min_size),
        max_size=float(max_size),
        size_granularity=float(size_granularity),
        # Leverage
        leverage=float(leverage),
        leverage_mode=int(leverage_mode),
        # Rejection and logging behavior
        reject_prob=float(reject_prob),
        price_area_vio_mode=int(price_area_vio_mode),
        allow_partial=bool(allow_partial),
        raise_reject=bool(raise_reject),
        log=bool(log),
    )
    return new_order
@register_jitted(cache=True)
def order_nothing_nb() -> Order:
    """Return `NoOrder`, the order that stands for ordering nothing."""
    return NoOrder


@register_jitted(cache=True)
def check_group_lens_nb(group_lens: tp.GroupLens, n_cols: int) -> None:
    """Validate that the group lengths add up to the number of columns.

    Raises:
        ValueError: If the sum of `group_lens` differs from `n_cols`.
    """
    total_cols = np.sum(group_lens)
    if total_cols != n_cols:
        raise ValueError("group_lens has incorrect total number of columns")


@register_jitted(cache=True)
def is_grouped_nb(group_lens: tp.GroupLens) -> bool:
    """Check if columns are grouped, that is, at least one group has more than one column."""
    has_multiple_cols = group_lens > 1
    return np.any(has_multiple_cols)


@register_jitted(cache=True)
def prepare_records_nb(
    target_shape: tp.Shape,
    max_order_records: tp.Optional[int] = None,
    max_log_records: tp.Optional[int] = 0,
) -> tp.Tuple[tp.RecordArray2d, tp.RecordArray2d]:
    """Allocate uninitialized order and log record arrays.

    Both arrays have one column per column in `target_shape`. The number of rows
    is the respective maximum if provided, otherwise the number of rows in `target_shape`.

    Args:
        target_shape (tuple): Shape whose first two entries are rows and columns.
        max_order_records (int): Rows of the order records, or None for one per row.
        max_log_records (int): Rows of the log records, or None for one per row.

    Returns:
        tuple: Order records (dtype `order_dt`) and log records (dtype `log_dt`).
    """
    n_rows, n_cols = target_shape[0], target_shape[1]

    if max_order_records is not None:
        order_records = np.empty((max_order_records, n_cols), dtype=order_dt)
    else:
        order_records = np.empty((n_rows, n_cols), dtype=order_dt)

    if max_log_records is not None:
        log_records = np.empty((max_log_records, n_cols), dtype=log_dt)
    else:
        log_records = np.empty((n_rows, n_cols), dtype=log_dt)

    return order_records, log_records
@register_jitted(cache=True)
def prepare_last_position_nb(target_shape: tp.Shape, init_position: tp.FlexArray1d) -> tp.Array1d:
    """Build the `last_position` array: one initial position per column.

    Each entry is selected from `init_position` with `flex_select_1d_pc_nb`
    and cast to `float`.
    """
    n_cols = target_shape[1]
    positions = np.empty(n_cols, dtype=float_)
    for c in range(n_cols):
        positions[c] = float(flex_select_1d_pc_nb(init_position, c))
    return positions


@register_jitted(cache=True)
def prepare_last_value_nb(
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    init_cash: tp.FlexArray1d,
    init_position: tp.FlexArray1d,
    init_price: tp.FlexArray1d,
) -> tp.Array1d:
    """Build the `last_value` array from initial cash, positions, and prices.

    With cash sharing, there is one value per group: the initial cash of the group
    plus position times price of every column in that group. Without cash sharing,
    there is one value per column. A zero position contributes nothing, so its
    price is never multiplied in.
    """
    if cash_sharing:
        n_groups = len(group_lens)
        group_values = np.empty(n_groups, dtype=float_)
        col_start = 0
        for g in range(n_groups):
            col_stop = col_start + group_lens[g]
            group_values[g] = float(flex_select_1d_pc_nb(init_cash, g))
            # Add the asset value of each column that belongs to this group
            for c in range(col_start, col_stop):
                pos = float(flex_select_1d_pc_nb(init_position, c))
                px = float(flex_select_1d_pc_nb(init_price, c))
                if pos != 0:
                    group_values[g] += pos * px
            col_start = col_stop
        return group_values

    n_cols = target_shape[1]
    col_values = np.empty(n_cols, dtype=float_)
    for c in range(n_cols):
        cash = float(flex_select_1d_pc_nb(init_cash, c))
        pos = float(flex_select_1d_pc_nb(init_position, c))
        px = float(flex_select_1d_pc_nb(init_price, c))
        if pos != 0:
            col_values[c] = cash + pos * px
        else:
            col_values[c] = cash
    return col_values
last_pos_info["entry_idx"][:] = -1 last_pos_info["entry_price"][:] = np.nan last_pos_info["entry_fees"][:] = np.nan last_pos_info["exit_order_id"][:] = -1 last_pos_info["exit_idx"][:] = -1 last_pos_info["exit_price"][:] = np.nan last_pos_info["exit_fees"][:] = np.nan last_pos_info["pnl"][:] = np.nan last_pos_info["return"][:] = np.nan last_pos_info["direction"][:] = -1 last_pos_info["status"][:] = -1 last_pos_info["parent_id"][:] = -1 for col in range(target_shape[1]): _init_position = float(flex_select_1d_pc_nb(init_position, col)) _init_price = float(flex_select_1d_pc_nb(init_price, col)) if _init_position != 0: fill_init_pos_info_nb(last_pos_info[col], col, _init_position, _init_price) else: last_pos_info = np.empty(0, dtype=trade_dt) return last_pos_info @register_jitted def prepare_sim_out_nb( order_records: tp.RecordArray2d, order_counts: tp.Array1d, log_records: tp.RecordArray2d, log_counts: tp.Array1d, cash_deposits: tp.Array2d, cash_earnings: tp.Array2d, call_seq: tp.Optional[tp.Array2d] = None, in_outputs: tp.Optional[tp.NamedTuple] = None, sim_start: tp.Optional[tp.Array1d] = None, sim_end: tp.Optional[tp.Array1d] = None, ) -> SimulationOutput: """Prepare simulation output.""" order_records_flat = generic_nb.repartition_nb(order_records, order_counts) log_records_flat = generic_nb.repartition_nb(log_records, log_counts) return SimulationOutput( order_records=order_records_flat, log_records=log_records_flat, cash_deposits=cash_deposits, cash_earnings=cash_earnings, call_seq=call_seq, in_outputs=in_outputs, sim_start=sim_start, sim_end=sim_end, ) @register_jitted(cache=True) def get_trade_stats_nb( size: float, entry_price: float, entry_fees: float, exit_price: float, exit_fees: float, direction: int, ) -> tp.Tuple[float, float]: """Get trade statistics.""" entry_val = size * entry_price exit_val = size * exit_price val_diff = add_nb(exit_val, -entry_val) if val_diff != 0 and direction == TradeDirection.Short: val_diff *= -1 pnl = val_diff - entry_fees - 
@register_jitted(cache=True)
def update_open_pos_info_stats_nb(record: tp.Record, position_now: float, price: float) -> None:
    """Update statistics of an open position record using custom price.

    Does nothing unless the record has a non-negative id and the status `TradeStatus.Open`.
    Otherwise, writes `pnl` and `return` computed by `get_trade_stats_nb`, where the
    exit price is `price` if nothing has been exited yet (exit price is NaN), or the
    size-weighted average of the recorded exit price and `price` applied to the
    remaining `abs(position_now)`."""
    if record["id"] < 0 or record["status"] != TradeStatus.Open:
        return

    if np.isnan(record["exit_price"]):
        avg_exit_price = price
    else:
        remaining_size = abs(position_now)
        exited_size = record["size"] - remaining_size
        gross = exited_size * record["exit_price"]
        gross += remaining_size * price
        avg_exit_price = gross / record["size"]

    new_pnl, new_ret = get_trade_stats_nb(
        record["size"],
        record["entry_price"],
        record["entry_fees"],
        avg_exit_price,
        record["exit_fees"],
        record["direction"],
    )
    record["pnl"] = new_pnl
    record["return"] = new_ret
record["entry_order_id"] = order_id record["entry_idx"] = i record["entry_price"] = order_result.price record["entry_fees"] = order_result.fees record["exit_order_id"] = -1 record["exit_idx"] = -1 record["exit_price"] = np.nan record["exit_fees"] = 0.0 if order_result.side == OrderSide.Buy: record["direction"] = TradeDirection.Long else: record["direction"] = TradeDirection.Short record["status"] = TradeStatus.Open record["parent_id"] = record["id"] elif position_before != 0 and position_now == 0: # Position closed record["exit_order_id"] = order_id record["exit_idx"] = i if np.isnan(record["exit_price"]): exit_price = order_result.price else: exit_size_sum = record["size"] - abs(position_before) exit_gross_sum = exit_size_sum * record["exit_price"] exit_gross_sum += abs(position_before) * order_result.price exit_price = exit_gross_sum / record["size"] record["exit_price"] = exit_price record["exit_fees"] += order_result.fees pnl, ret = get_trade_stats_nb( record["size"], record["entry_price"], record["entry_fees"], record["exit_price"], record["exit_fees"], record["direction"], ) record["pnl"] = pnl record["return"] = ret record["status"] = TradeStatus.Closed elif np.sign(position_before) != np.sign(position_now): # Position reversed record["id"] += 1 record["size"] = abs(position_now) record["entry_order_id"] = order_id record["entry_idx"] = i record["entry_price"] = order_result.price new_pos_fraction = abs(position_now) / abs(position_now - position_before) record["entry_fees"] = new_pos_fraction * order_result.fees record["exit_order_id"] = -1 record["exit_idx"] = -1 record["exit_price"] = np.nan record["exit_fees"] = 0.0 if order_result.side == OrderSide.Buy: record["direction"] = TradeDirection.Long else: record["direction"] = TradeDirection.Short record["status"] = TradeStatus.Open record["parent_id"] = record["id"] else: # Position changed if abs(position_before) <= abs(position_now): # Position increased entry_gross_sum = record["size"] * 
@register_jitted(cache=True)
def resolve_hl_nb(open, high, low, close):
    """Resolve the current high and low.

    A NaN high (low) is replaced by the maximum (minimum) of open and close; if one of
    open and close is NaN, the other one is used. Non-NaN inputs are returned as they are."""
    high_missing = np.isnan(high)
    low_missing = np.isnan(low)
    if high_missing or low_missing:
        if np.isnan(open):
            upper = close
            lower = close
        elif np.isnan(close):
            upper = open
            lower = open
        else:
            upper = max(open, close)
            lower = min(open, close)
        if high_missing:
            high = upper
        if low_missing:
            low = lower
    return high, low
@register_jitted(cache=True)
def resolve_stop_exit_price_nb(
    stop_price: float,
    close: float,
    stop_exit_price: float,
) -> float:
    """Resolve the exit price of a stop order.

    `StopExitPrice.Stop` and `StopExitPrice.HardStop` resolve to the stop price,
    `StopExitPrice.Close` resolves to the closing price, and any other negative value
    is rejected. Everything else is treated as the price itself."""
    wants_stop_price = stop_exit_price == StopExitPrice.Stop or stop_exit_price == StopExitPrice.HardStop
    if wants_stop_price:
        return float(stop_price)
    if stop_exit_price == StopExitPrice.Close:
        return float(close)
    if stop_exit_price < 0:
        raise ValueError("Invalid StopExitPrice option")
    return float(stop_exit_price)
@register_jitted(cache=True)
def should_update_time_stop_nb(new_stop: int, upon_stop_update: int) -> bool:
    """Whether to update time stop.

    `StopUpdateMode.Keep` never updates, `StopUpdateMode.OverrideNaN` always updates,
    and `StopUpdateMode.Override` updates only if the new stop is set (not -1).
    Any other mode raises an error."""
    if upon_stop_update == StopUpdateMode.Keep:
        return False
    if upon_stop_update == StopUpdateMode.OverrideNaN:
        return True
    if upon_stop_update == StopUpdateMode.Override:
        return new_stop != -1
    raise ValueError("Invalid StopUpdateMode option")
@register_jitted(cache=True)
def resolve_limit_price_nb(
    init_price: float,
    limit_delta: float = np.nan,
    delta_format: int = DeltaFormat.Percent,
    hit_below: bool = True,
) -> float:
    """Resolve the limit price.

    A NaN delta leaves the initial price untouched. `DeltaFormat.Percent100` is converted
    to `DeltaFormat.Percent` first. The delta is subtracted from the initial price if
    `hit_below` is True and added otherwise; `DeltaFormat.Target` uses the delta as the
    price itself. An infinite delta (except with `DeltaFormat.Target`) resolves directly
    to an infinite price of the corresponding sign."""
    if delta_format == DeltaFormat.Percent100:
        limit_delta /= 100
        delta_format = DeltaFormat.Percent

    if np.isnan(limit_delta):
        return init_price

    if delta_format != DeltaFormat.Target and np.isinf(limit_delta):
        # Moving infinitely far below (or infinitely far "negatively above") the initial price
        goes_to_neg_inf = (hit_below and limit_delta > 0) or (not hit_below and limit_delta < 0)
        if goes_to_neg_inf:
            return -np.inf
        return np.inf

    if delta_format == DeltaFormat.Absolute:
        if hit_below:
            return init_price - limit_delta
        return init_price + limit_delta
    if delta_format == DeltaFormat.Percent:
        if hit_below:
            return init_price * (1 - limit_delta)
        return init_price * (1 + limit_delta)
    if delta_format == DeltaFormat.Target:
        return limit_delta
    raise ValueError("Invalid DeltaFormat option")
@register_jitted(cache=True)
def resolve_limit_order_price_nb(
    limit_price: float,
    close: float,
    limit_order_price: float,
) -> float:
    """Resolve the limit order price of a limit order.

    `LimitOrderPrice.Limit` and `LimitOrderPrice.HardLimit` resolve to the limit price,
    `LimitOrderPrice.Close` resolves to the closing price, and any other negative value
    is rejected. Everything else is treated as the price itself."""
    wants_limit_price = (
        limit_order_price == LimitOrderPrice.Limit or limit_order_price == LimitOrderPrice.HardLimit
    )
    if wants_limit_price:
        return float(limit_price)
    if limit_order_price == LimitOrderPrice.Close:
        return float(close)
    if limit_order_price < 0:
        raise ValueError("Invalid LimitOrderPrice option")
    return float(limit_order_price)
@register_jitted(cache=True)
def check_stop_hit_nb(
    open: float,
    high: float,
    low: float,
    close: float,
    is_position_long: bool,
    init_price: float,
    stop: float,
    delta_format: int = DeltaFormat.Percent,
    hit_below: bool = True,
    can_use_ohlc: bool = True,
    check_open: bool = True,
    hard_stop: bool = False,
) -> tp.Tuple[float, bool, bool]:
    """Resolve the stop price using `resolve_stop_price_nb` and check whether it was hit.

    The argument `hit_below` is given from the perspective of a long position and
    gets flipped for a short position.

    See `check_price_hit_nb`."""
    if is_position_long:
        resolved_hit_below = hit_below
    else:
        resolved_hit_below = not hit_below

    target_price = resolve_stop_price_nb(
        init_price=init_price,
        stop=stop,
        delta_format=delta_format,
        hit_below=resolved_hit_below,
    )
    return check_price_hit_nb(
        open=open,
        high=high,
        low=low,
        close=close,
        price=target_price,
        hit_below=resolved_hit_below,
        can_use_ohlc=can_use_ohlc,
        check_open=check_open,
        hard_price=hard_stop,
    )
@register_jitted(cache=True)
def check_dt_stop_hit_nb(
    i: int,
    stop: int = -1,
    time_delta_format: int = TimeDeltaFormat.Index,
    index: tp.Optional[tp.Array1d] = None,
    freq: tp.Optional[int] = None,
) -> tp.Tuple[bool, bool]:
    """Check whether DT stop was hit by comparing the current index with the stop.

    In contrast to a TD stop, the stop here is not offset by any initial index: it is
    compared as-is against the current row `i` (`TimeDeltaFormat.Rows`) or against
    `index[i]` (`TimeDeltaFormat.Index`).

    Args:
        i: Current row.
        stop: Stop to compare against. The value -1 means no stop.
        time_delta_format: See `TimeDeltaFormat`.
        index: Index values; required for `TimeDeltaFormat.Index`.
        freq: Length of a bar in index units; required for `TimeDeltaFormat.Index`.

    Returns whether the stop was hit already on open, and whether the stop was hit during this bar.

    Raises:
        ValueError: If `index` or `freq` is missing for `TimeDeltaFormat.Index`,
            or if `time_delta_format` is not a supported option."""
    if stop == -1:
        # No stop has been set
        return False, False

    if time_delta_format == TimeDeltaFormat.Rows:
        if stop <= i:
            return True, True
        # The bar covers the open interval (i, i + 1)
        if i < stop < i + 1:
            return False, True
        return False, False

    if time_delta_format == TimeDeltaFormat.Index:
        if index is None:
            raise ValueError("Must provide index for TimeDeltaFormat.Index")
        if freq is None:
            raise ValueError("Must provide frequency for TimeDeltaFormat.Index")
        if stop <= index[i]:
            return True, True
        # The bar covers the open interval (index[i], index[i] + freq)
        if index[i] < stop < index[i] + freq:
            return False, True
        return False, False

    raise ValueError("Invalid TimeDeltaFormat option")
@register_jitted(cache=True)
def get_stop_ladder_exit_size_nb(
    stop_: tp.FlexArray2d,
    step: int,
    col: int,
    init_price: float,
    init_position: float,
    position_now: float,
    ladder: int = StopLadderMode.Disabled,
    delta_format: int = DeltaFormat.Percent,
    hit_below: bool = True,
) -> float:
    """Get the exit size corresponding to the current step in the ladder.

    The ladder of a column consists of the consecutive non-NaN values in `stop_`
    starting at row `step`. The last step always exits the entire current position.
    For any other step, the exit size depends on the ladder mode:

    * `StopLadderMode.Uniform`: equal fraction of the initial position per step
    * `StopLadderMode.AdaptUniform`: equal fraction of the current position per remaining step
    * `StopLadderMode.Weighted`: fraction of the initial position proportional to the
        price distance between this and the previous step, relative to the distance
        between the initial price and the last step's price
    * `StopLadderMode.AdaptWeighted`: same, but relative to the distance between the
        previous step's price and the last step's price, applied to the current position

    Returns NaN if the stop at the current step is NaN or no last step could be found.

    Raises an error if the ladder is disabled, dynamic, or of an unknown mode."""
    if ladder == StopLadderMode.Disabled:
        raise ValueError("Stop ladder must be enabled to select exit size")
    if ladder == StopLadderMode.Dynamic:
        raise ValueError("Stop ladder must be static to select exit size")
    stop = flex_select_nb(stop_, step, col)
    if np.isnan(stop):
        return np.nan
    # Find the last step of the ladder: scan forward until the first NaN
    last_step = -1
    for i in range(step, stop_.shape[0]):
        if not np.isnan(flex_select_nb(stop_, i, col)):
            last_step = i
        else:
            break
    # Stays -1 only if the loop above did not run at all, i.e., step >= stop_.shape[0]
    if last_step == -1:
        return np.nan
    # The last step closes out whatever is left of the position
    if step == last_step:
        return abs(position_now)

    if ladder == StopLadderMode.Uniform:
        # last_step + 1 is the step count when counted from row 0
        exit_fraction = 1 / (last_step + 1)
        return exit_fraction * abs(init_position)
    if ladder == StopLadderMode.AdaptUniform:
        # Number of steps remaining, including the current one
        exit_fraction = 1 / (last_step + 1 - step)
        return exit_fraction * abs(position_now)
    # Weighted modes need prices: flip the hit direction for a short initial position
    hit_below = (init_position >= 0 and hit_below) or (init_position < 0 and not hit_below)
    price = resolve_stop_price_nb(
        init_price=init_price,
        stop=stop,
        delta_format=delta_format,
        hit_below=hit_below,
    )
    last_stop = flex_select_nb(stop_, last_step, col)
    last_price = resolve_stop_price_nb(
        init_price=init_price,
        stop=last_stop,
        delta_format=delta_format,
        hit_below=hit_below,
    )
    # The first step is measured from the initial price
    if step == 0:
        prev_price = init_price
    else:
        prev_stop = flex_select_nb(stop_, step - 1, col)
        prev_price = resolve_stop_price_nb(
            init_price=init_price,
            stop=prev_stop,
            delta_format=delta_format,
            hit_below=hit_below,
        )
    if ladder == StopLadderMode.Weighted:
        # NOTE(review): the denominator is zero if the last step resolves to init_price - confirm callers rule this out
        exit_fraction = (price - prev_price) / (last_price - init_price)
        return exit_fraction * abs(init_position)
    if ladder == StopLadderMode.AdaptWeighted:
        # NOTE(review): the denominator is zero if the last and previous steps resolve to the same price - confirm
        exit_fraction = (price - prev_price) / (last_price - prev_price)
        return exit_fraction * abs(position_now)
    raise ValueError("Invalid StopLadderMode option")
@register_jitted(cache=True)
def set_limit_info_nb(
    limit_info: tp.Record,
    signal_idx: int,
    creation_idx: tp.Optional[int] = None,
    init_idx: tp.Optional[int] = None,
    init_price: float = -np.inf,
    init_size: float = np.inf,
    init_size_type: int = SizeType.Amount,
    init_direction: int = Direction.Both,
    init_stop_type: int = -1,
    delta: float = np.nan,
    delta_format: int = DeltaFormat.Percent,
    tif: int = -1,
    expiry: int = -1,
    time_delta_format: int = TimeDeltaFormat.Index,
    reverse: bool = False,
    order_price: int = LimitOrderPrice.Limit,
) -> None:
    """Set limit order information.

    Writes each argument into the field of the same name. If `creation_idx` or
    `init_idx` is None, the respective field falls back to `signal_idx`.

    See `vectorbtpro.portfolio.enums.limit_info_dt`."""
    limit_info["signal_idx"] = signal_idx

    # Indices that default to the signal index
    if creation_idx is None:
        limit_info["creation_idx"] = signal_idx
    else:
        limit_info["creation_idx"] = creation_idx
    if init_idx is None:
        limit_info["init_idx"] = signal_idx
    else:
        limit_info["init_idx"] = init_idx

    # Initial order
    limit_info["init_price"] = init_price
    limit_info["init_size"] = init_size
    limit_info["init_size_type"] = init_size_type
    limit_info["init_direction"] = init_direction
    limit_info["init_stop_type"] = init_stop_type

    # Price delta
    limit_info["delta"] = delta
    limit_info["delta_format"] = delta_format

    # Expiration
    limit_info["tif"] = tif
    limit_info["expiry"] = expiry
    limit_info["time_delta_format"] = time_delta_format

    # Execution
    limit_info["reverse"] = reverse
    limit_info["order_price"] = order_price
@register_jitted(cache=True)
def clear_sl_info_nb(sl_info: tp.Record) -> None:
    """Clear SL order information.

    Resets the fields `init_price`, `init_position`, `stop`, `exit_size`, and
    `limit_delta` to NaN, and all other fields to -1."""
    # Fields cleared with NaN
    sl_info["init_price"] = np.nan
    sl_info["init_position"] = np.nan
    sl_info["stop"] = np.nan
    sl_info["exit_size"] = np.nan
    sl_info["limit_delta"] = np.nan

    # Fields cleared with -1
    sl_info["init_idx"] = -1
    sl_info["exit_price"] = -1
    sl_info["exit_size_type"] = -1
    sl_info["exit_type"] = -1
    sl_info["order_type"] = -1
    sl_info["delta_format"] = -1
    sl_info["ladder"] = -1
    sl_info["step"] = -1
    sl_info["step_idx"] = -1
@register_jitted(cache=True)
def clear_tsl_info_nb(tsl_info: tp.Record) -> None:
    """Clear TSL/TTP order information.

    Resets the fields `init_price`, `init_position`, `peak_price`, `stop`, `th`,
    `exit_size`, and `limit_delta` to NaN, and all other fields to -1."""
    # Fields cleared with NaN
    tsl_info["init_price"] = np.nan
    tsl_info["init_position"] = np.nan
    tsl_info["peak_price"] = np.nan
    tsl_info["stop"] = np.nan
    tsl_info["th"] = np.nan
    tsl_info["exit_size"] = np.nan
    tsl_info["limit_delta"] = np.nan

    # Fields cleared with -1
    tsl_info["init_idx"] = -1
    tsl_info["peak_idx"] = -1
    tsl_info["exit_price"] = -1
    tsl_info["exit_size_type"] = -1
    tsl_info["exit_type"] = -1
    tsl_info["order_type"] = -1
    tsl_info["delta_format"] = -1
    tsl_info["ladder"] = -1
    tsl_info["step"] = -1
    tsl_info["step_idx"] = -1
@register_jitted(cache=True)
def clear_tp_info_nb(tp_info: tp.Record) -> None:
    """Clear TP order information.

    Resets the fields `init_price`, `init_position`, `stop`, `exit_size`, and
    `limit_delta` to NaN, and all other fields to -1."""
    # Fields cleared with NaN
    tp_info["init_price"] = np.nan
    tp_info["init_position"] = np.nan
    tp_info["stop"] = np.nan
    tp_info["exit_size"] = np.nan
    tp_info["limit_delta"] = np.nan

    # Fields cleared with -1
    tp_info["init_idx"] = -1
    tp_info["exit_price"] = -1
    tp_info["exit_size_type"] = -1
    tp_info["exit_type"] = -1
    tp_info["order_type"] = -1
    tp_info["delta_format"] = -1
    tp_info["ladder"] = -1
    tp_info["step"] = -1
    tp_info["step_idx"] = -1
@register_jitted(cache=True)
def get_limit_info_target_price_nb(limit_info: tp.Record) -> float:
    """Get target price from limit order information.

    Returns NaN if the record is not active. Raises an error if the initial size is zero.
    Otherwise, resolves the price with `resolve_limit_price_nb`, where the price is
    expected to be hit from above (`hit_below=True`) if the direction-aware size is
    positive, or negative in case the limit is reversed."""
    if not is_limit_info_active_nb(limit_info):
        return np.nan
    if limit_info["init_size"] == 0:
        raise ValueError("Limit order size cannot be zero")

    signed_size = get_diraware_size_nb(limit_info["init_size"], limit_info["init_direction"])
    if limit_info["reverse"]:
        expect_below = signed_size < 0
    else:
        expect_below = signed_size > 0
    return resolve_limit_price_nb(
        init_price=limit_info["init_price"],
        limit_delta=limit_info["delta"],
        delta_format=limit_info["delta_format"],
        hit_below=expect_below,
    )
init_price=sl_info["init_price"], stop=sl_info["stop"], delta_format=sl_info["delta_format"], hit_below=hit_below, ) @register_jitted def get_tsl_info_target_price_nb(tsl_info: tp.Record, position_now: float) -> float: """Get target price from TSL/TTP order information.""" if not is_stop_info_active_nb(tsl_info): return np.nan hit_below = position_now > 0 return resolve_stop_price_nb( init_price=tsl_info["peak_price"], stop=tsl_info["stop"], delta_format=tsl_info["delta_format"], hit_below=hit_below, ) @register_jitted def get_tp_info_target_price_nb(tp_info: tp.Record, position_now: float) -> float: """Get target price from TP order information.""" if not is_stop_info_active_nb(tp_info): return np.nan hit_below = position_now < 0 return resolve_stop_price_nb( init_price=tp_info["init_price"], stop=tp_info["stop"], delta_format=tp_info["delta_format"], hit_below=hit_below, ) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# ===================================================================================

"""Numba-compiled context helper functions for portfolio simulation."""

from vectorbtpro.base.flex_indexing import flex_select_col_nb
from vectorbtpro.portfolio.nb import records as pf_records_nb
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.iter_ import select_nb
from vectorbtpro.records import nb as records_nb

# ############# Position ############# #


@register_jitted
def get_col_position_nb(c: tp.NamedTuple, col: int) -> float:
    """Get position of a column.

    Reads `c.last_position[col]`."""
    return c.last_position[col]


@register_jitted
def get_position_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get position of the current column."""
    return get_col_position_nb(c, c.col)


@register_jitted
def col_in_position_nb(c: tp.NamedTuple, col: int) -> bool:
    """Check whether a column is in a position.

    A column is in a position whenever its last position is non-zero."""
    position = get_col_position_nb(c, col)
    return position != 0


@register_jitted
def in_position_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the current column is in a position."""
    return col_in_position_nb(c, c.col)


@register_jitted
def col_in_long_position_nb(c: tp.NamedTuple, col: int) -> bool:
    """Check whether a column is in a long position.

    A long position is a strictly positive last position."""
    position = get_col_position_nb(c, col)
    return position > 0


@register_jitted
def in_long_position_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the current column is in a long position."""
    return col_in_long_position_nb(c, c.col)


@register_jitted
def col_in_short_position_nb(c: tp.NamedTuple, col: int) -> bool:
    """Check whether a column is in a short position.

    A short position is a strictly negative last position."""
    position = get_col_position_nb(c, col)
    return position < 0


@register_jitted
def in_short_position_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the current column is in a short position."""
    return col_in_short_position_nb(c, c.col)


@register_jitted
def get_n_active_positions_nb(
    c: tp.Union[
        GroupContext,
        SegmentContext,
        OrderContext,
        PostOrderContext,
        FlexOrderContext,
        SignalSegmentContext,
        SignalContext,
        PostSignalContext,
    ],
    all_groups: bool = False,
) -> int:
    """Get the number of active positions in the current group (regardless of cash sharing).

    Counts columns between `c.from_col` (inclusive) and `c.to_col` (exclusive) whose
    last position is non-zero.

    To calculate across all groups, set `all_groups` to True."""
    n_active_positions = 0
    if all_groups:
        for col in range(c.target_shape[1]):
            if c.last_position[col] != 0:
                n_active_positions += 1
    else:
        for col in range(c.from_col, c.to_col):
            if c.last_position[col] != 0:
                n_active_positions += 1
    return n_active_positions


# ############# Cash ############# #


@register_jitted
def get_col_cash_nb(c: tp.NamedTuple, col: int) -> float:
    """Get cash of a column.

    Raises an error under cash sharing, where `last_cash` is indexed by group
    rather than by column (see `get_group_cash_nb`)."""
    if c.cash_sharing:
        raise ValueError("Cannot get cash of a single column from a group with cash sharing. " "Use get_group_cash_nb.")
    return c.last_cash[col]


@register_jitted
def get_group_cash_nb(c: tp.NamedTuple, group: int) -> float:
    """Get cash of a group.

    With cash sharing, reads `c.last_cash[group]` directly. Otherwise, sums the cash of
    all columns of the group, with column bounds derived from `c.group_lens`.
    Returns 0.0 if `group` is not among the groups in `c.group_lens`."""
    if c.cash_sharing:
        return c.last_cash[group]
    cash = 0.0
    from_col = 0
    for g in range(len(c.group_lens)):
        to_col = from_col + c.group_lens[g]
        if g == group:
            for col in range(from_col, to_col):
                cash += c.last_cash[col]
            break
        from_col = to_col
    return cash


@register_jitted
def get_cash_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get cash of the current column or group with cash sharing."""
    if c.cash_sharing:
        return get_group_cash_nb(c, c.group)
    return get_col_cash_nb(c, c.col)


# ############# Debt ############# #


@register_jitted
def get_col_debt_nb(c: tp.NamedTuple, col: int) -> float:
    """Get debt of a column.

    Reads `c.last_debt[col]`."""
    return c.last_debt[col]


@register_jitted
def get_debt_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get debt of the current column."""
    return get_col_debt_nb(c, c.col)


# ############# Locked cash ############# #


@register_jitted
def get_col_locked_cash_nb(c: tp.NamedTuple, col: int) -> float:
    """Get locked cash of a column.

    Reads `c.last_locked_cash[col]`."""
    return c.last_locked_cash[col]


@register_jitted
def get_locked_cash_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get locked cash of the current column."""
    return get_col_locked_cash_nb(c, c.col)


# ############# Free cash ############# #


@register_jitted
def get_col_free_cash_nb(c: tp.NamedTuple, col: int) -> float:
    """Get free cash of a column.

    Raises an error under cash sharing, where `last_free_cash` is indexed by group
    rather than by column (see `get_group_free_cash_nb`)."""
    if c.cash_sharing:
        raise ValueError(
            "Cannot get free cash of a single column from a group with cash sharing. " "Use get_group_free_cash_nb."
        )
    return c.last_free_cash[col]


@register_jitted
def get_group_free_cash_nb(c: tp.NamedTuple, group: int) -> float:
    """Get free cash of a group.

    With cash sharing, reads `c.last_free_cash[group]` directly. Otherwise, sums the free cash
    of all columns of the group, with column bounds derived from `c.group_lens`.
    Returns 0.0 if `group` is not among the groups in `c.group_lens`."""
    if c.cash_sharing:
        return c.last_free_cash[group]
    free_cash = 0.0
    from_col = 0
    for g in range(len(c.group_lens)):
        to_col = from_col + c.group_lens[g]
        if g == group:
            for col in range(from_col, to_col):
                free_cash += c.last_free_cash[col]
            break
        from_col = to_col
    return free_cash


@register_jitted
def get_free_cash_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get free cash of the current column or group with cash sharing."""
    if c.cash_sharing:
        return get_group_free_cash_nb(c, c.group)
    return get_col_free_cash_nb(c, c.col)


@register_jitted
def col_has_free_cash_nb(c: tp.NamedTuple, col: int) -> bool:
    """Check whether a column has free cash.

    True if the free cash of the column is strictly positive. Raises an error
    under cash sharing (see `get_col_free_cash_nb`)."""
    return get_col_free_cash_nb(c, col) > 0


@register_jitted
def group_has_free_cash_nb(c: tp.NamedTuple, group: int) -> bool:
    """Check whether a group has free cash.

    True if the free cash of the group is strictly positive."""
    return get_group_free_cash_nb(c, group) > 0


@register_jitted
def has_free_cash_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the current column or group with cash sharing has free cash."""
    if c.cash_sharing:
        return group_has_free_cash_nb(c, c.group)
    return col_has_free_cash_nb(c, c.col)


# ############# Valuation price ############# #


@register_jitted
def get_col_val_price_nb(c: tp.NamedTuple, col: int) -> float:
    """Get valuation price of a column.

    Reads `c.last_val_price[col]`."""
    return c.last_val_price[col]


@register_jitted
def get_val_price_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get valuation price of the current column."""
    return get_col_val_price_nb(c, c.col)


# ############# Value ############# #


@register_jitted
def get_col_value_nb(c: tp.NamedTuple, col: int) -> float:
    """Get value of a column.

    Raises an error under cash sharing, where `last_value` is indexed by group
    rather than by column (see `get_group_value_nb`)."""
    if c.cash_sharing:
        raise ValueError(
            "Cannot get value of a single column from a group with cash sharing. " "Use get_group_value_nb."
        )
    return c.last_value[col]


@register_jitted
def get_group_value_nb(c: tp.NamedTuple, group: int) -> float:
    """Get value of a group.

    With cash sharing, reads `c.last_value[group]` directly. Otherwise, sums the value
    of all columns of the group, with column bounds derived from `c.group_lens`.
    Returns 0.0 if `group` is not among the groups in `c.group_lens`."""
    if c.cash_sharing:
        return c.last_value[group]
    value = 0.0
    from_col = 0
    for g in range(len(c.group_lens)):
        to_col = from_col + c.group_lens[g]
        if g == group:
            for col in range(from_col, to_col):
                value += c.last_value[col]
            break
        from_col = to_col
    return value


@register_jitted
def get_value_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get value of the current column or group with cash sharing."""
    if c.cash_sharing:
        return get_group_value_nb(c, c.group)
    return get_col_value_nb(c, c.col)


# ############# Leverage ############# #


@register_jitted
def get_col_leverage_nb(c: tp.NamedTuple, col: int) -> float:
    """Get leverage of a column.

    Computed as debt divided by locked cash, plus one if the position is positive.
    Returns NaN if there is no locked cash."""
    position = get_col_position_nb(c, col)
    debt = get_col_debt_nb(c, col)
    locked_cash = get_col_locked_cash_nb(c, col)
    if locked_cash == 0:
        # Avoid division by zero
        return np.nan
    leverage = debt / locked_cash
    if position > 0:
        leverage += 1
    return leverage


@register_jitted
def get_leverage_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get leverage of the current column."""
    return get_col_leverage_nb(c, c.col)


# ############# Allocation ############# #


@register_jitted
def get_col_position_value_nb(c: tp.NamedTuple, col: int) -> float:
    """Get position value of a column.

    Computed as position times valuation price. Returns 0.0 for a zero position
    without involving the valuation price."""
    position = get_col_position_nb(c, col)
    val_price = get_col_val_price_nb(c, col)
    if position == 0:
        # Short-circuit so that a NaN valuation price doesn't turn the value into NaN
        return 0.0
    return position * val_price


@register_jitted
def get_group_position_value_nb(c: tp.NamedTuple, group: int) -> float:
    """Get position value of a group.

    Sums the position value of all columns of the group, with column bounds
    derived from `c.group_lens`. Returns 0.0 if `group` is not among the groups
    in `c.group_lens`."""
    value = 0.0
    from_col = 0
    for g in range(len(c.group_lens)):
        to_col = from_col + c.group_lens[g]
        if g == group:
            for col in range(from_col, to_col):
                value += get_col_position_value_nb(c, col)
            break
        from_col = to_col
    return value


@register_jitted
def get_position_value_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get position value of the current column."""
    return get_col_position_value_nb(c, c.col)


@register_jitted
def get_col_allocation_nb(c: tp.NamedTuple, col: int, group: tp.Optional[int] = None) -> float:
    """Get allocation of a column in its group.

    Computed as the position value of the column divided by the value of its group.

    If `group` is None, the group is looked up from `c.group_lens`; an error is raised
    if the column falls outside of all groups. Returns 0.0 if the position value is zero,
    and NaN if the group value is zero or negative."""
    position_value = get_col_position_value_nb(c, col)
    if group is None:
        from_col = 0
        found = False
        for _group in range(len(c.group_lens)):
            to_col = from_col + c.group_lens[_group]
            if from_col <= col < to_col:
                found = True
                break
            from_col = to_col
        if not found:
            raise ValueError("Column out of bounds")
    else:
        _group = group
    value = get_group_value_nb(c, _group)
    if position_value == 0:
        return 0.0
    if value <= 0:
        return np.nan
    return position_value / value


@register_jitted
def get_allocation_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get allocation of the current column in the current group."""
    return get_col_allocation_nb(c, c.col, group=c.group)


# ############# Orders ############# #


@register_jitted
def get_col_order_count_nb(c: tp.NamedTuple, col: int) -> int:
    """Get number of order records for a column.

    Reads `c.order_counts[col]`."""
    return c.order_counts[col]


@register_jitted
def get_order_count_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> int:
    """Get number of order records for the current column."""
    return get_col_order_count_nb(c, c.col)


@register_jitted
def get_col_order_records_nb(c: tp.NamedTuple, col: int) -> tp.RecordArray:
    """Get order records for a column.

    Slices the first `order_counts[col]` rows of `c.order_records` for the column."""
    order_count = get_col_order_count_nb(c, col)
    return c.order_records[:order_count, col]


@register_jitted
def get_order_records_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> tp.RecordArray:
    """Get order records for the current column."""
    return get_col_order_records_nb(c, c.col)


@register_jitted
def col_has_orders_nb(c: tp.NamedTuple, col: int) -> bool:
    """Check whether there is any order in a column."""
    return get_col_order_count_nb(c, col) > 0


@register_jitted
def has_orders_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether there is any order in the current column."""
    return col_has_orders_nb(c, c.col)


@register_jitted
def get_col_last_order_nb(c: tp.NamedTuple, col: int) -> tp.Record:
    """Get the last order in a column.

    Raises an error if the column has no orders yet."""
    if not col_has_orders_nb(c, col):
        raise ValueError("There are no orders. Check for any orders first.")
    return get_col_order_records_nb(c, col)[-1]


@register_jitted
def get_last_order_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
) -> tp.Record:
    """Get the last order in the current column."""
    return get_col_last_order_nb(c, c.col)


# ############# Order result ############# #


@register_jitted
def order_filled_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order was filled."""
    return c.order_result.status == OrderStatus.Filled


@register_jitted
def order_opened_position_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order has opened a new position.

    A reversal also counts as opening a new position."""
    position_now = get_position_nb(c)
    return order_reversed_position_nb(c) or (c.position_before == 0 and position_now != 0)


@register_jitted
def order_increased_position_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order has opened or increased an existing position."""
    position_now = get_position_nb(c)
    return order_opened_position_nb(c) or (
        np.sign(position_now) == np.sign(c.position_before) and abs(position_now) > abs(c.position_before)
    )


@register_jitted
def order_decreased_position_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order has decreased or closed an existing position.

    A reversal also counts as decreasing the previous position."""
    position_now = get_position_nb(c)
    return (
        order_closed_position_nb(c)
        or order_reversed_position_nb(c)
        or (np.sign(position_now) == np.sign(c.position_before) and abs(position_now) < abs(c.position_before))
    )


@register_jitted
def order_closed_position_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order has closed out an existing position."""
    position_now = get_position_nb(c)
    return c.position_before != 0 and position_now == 0


@register_jitted
def order_reversed_position_nb(
    c: tp.Union[
        PostOrderContext,
        PostSignalContext,
    ],
) -> bool:
    """Check whether the order has reversed an existing position.

    True if the position was non-zero before, is non-zero now, and has changed its sign."""
    position_now = get_position_nb(c)
    return c.position_before != 0 and position_now != 0 and np.sign(c.position_before) != np.sign(position_now)


# ############# Limit orders ############# #


@register_jitted
def get_col_limit_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
    """Get limit order information of a column.

    Reads `c.last_limit_info[col]`."""
    return c.last_limit_info[col]


@register_jitted
def get_limit_info_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> tp.Record:
    """Get limit order information of the current column."""
    return get_col_limit_info_nb(c, c.col)


@register_jitted
def get_col_limit_target_price_nb(c: tp.NamedTuple, col: int) -> float:
    """Get target price of limit order in a column.

    Returns NaN if the column is not in a position or the limit order is not active."""
    # NOTE(review): this guard mirrors the stop-order helpers, but a limit order may presumably
    # be pending while the column has no position (e.g. a limit entry), in which case NaN is
    # returned here even though `get_limit_info_target_price_nb` could resolve a price - confirm
    if not col_in_position_nb(c, col):
        return np.nan
    limit_info = get_col_limit_info_nb(c, col)
    return get_limit_info_target_price_nb(limit_info)


@register_jitted
def get_limit_target_price_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get target price of limit order in the current column."""
    return get_col_limit_target_price_nb(c, c.col)


# ############# Stop orders ############# #


@register_jitted
def get_col_sl_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
    """Get SL order information of a column.

    Reads `c.last_sl_info[col]`."""
    return c.last_sl_info[col]


@register_jitted
def get_sl_info_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> tp.Record:
    """Get SL order information of the current column."""
    return get_col_sl_info_nb(c, c.col)


@register_jitted
def get_col_sl_target_price_nb(c: tp.NamedTuple, col: int) -> float:
    """Get target price of SL order in a column.

    Returns NaN if the column is not in a position."""
    if not col_in_position_nb(c, col):
        return np.nan
    position = get_col_position_nb(c, col)
    sl_info = get_col_sl_info_nb(c, col)
    return get_sl_info_target_price_nb(sl_info, position)


@register_jitted
def get_sl_target_price_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get target price of SL order in the current column."""
    return get_col_sl_target_price_nb(c, c.col)
@register_jitted
def get_col_tsl_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
    """Get TSL/TTP order information of a column.

    Reads `c.last_tsl_info[col]`."""
    return c.last_tsl_info[col]


@register_jitted
def get_tsl_info_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> tp.Record:
    """Get TSL/TTP order information of the current column."""
    return get_col_tsl_info_nb(c, c.col)


@register_jitted
def get_col_tsl_target_price_nb(c: tp.NamedTuple, col: int) -> float:
    """Get target price of TSL/TTP order in a column.

    Returns NaN if the column is not in a position."""
    if not col_in_position_nb(c, col):
        return np.nan
    position = get_col_position_nb(c, col)
    tsl_info = get_col_tsl_info_nb(c, col)
    return get_tsl_info_target_price_nb(tsl_info, position)


@register_jitted
def get_tsl_target_price_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get target price of TSL/TTP order in the current column."""
    return get_col_tsl_target_price_nb(c, c.col)


@register_jitted
def get_col_tp_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
    """Get TP order information of a column.

    Reads `c.last_tp_info[col]`."""
    return c.last_tp_info[col]


@register_jitted
def get_tp_info_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> tp.Record:
    """Get TP order information of the current column."""
    return get_col_tp_info_nb(c, c.col)


@register_jitted
def get_col_tp_target_price_nb(c: tp.NamedTuple, col: int) -> float:
    """Get target price of TP order in a column.

    Returns NaN if the column is not in a position."""
    if not col_in_position_nb(c, col):
        return np.nan
    position = get_col_position_nb(c, col)
    tp_info = get_col_tp_info_nb(c, col)
    return get_tp_info_target_price_nb(tp_info, position)


@register_jitted
def get_tp_target_price_nb(
    c: tp.Union[
        SignalContext,
        PostSignalContext,
    ],
) -> float:
    """Get target price of TP order in the current column."""
    return get_col_tp_target_price_nb(c, c.col)


# ############# Trades ############# #


@register_jitted
def get_col_entry_trade_records_nb(
    c: tp.NamedTuple,
    col: int,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get entry trade records of a column up to this point.

    Built from the order records of the column filled so far and the close price
    up to and including the current row (`c.i`). Arguments `init_position` and
    `init_price` are forwarded to the trade builder as-is."""
    order_records = get_col_order_records_nb(c, col)
    col_map = records_nb.col_map_nb(order_records["col"], c.target_shape[1])
    close = flex_select_col_nb(c.close, col)
    entry_trades = pf_records_nb.get_entry_trades_nb(
        order_records,
        close[: c.i + 1],
        col_map,
        init_position=init_position,
        init_price=init_price,
    )
    return entry_trades


@register_jitted
def get_entry_trade_records_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get entry trade records of the current column up to this point."""
    return get_col_entry_trade_records_nb(c, c.col, init_position=init_position, init_price=init_price)


@register_jitted
def get_col_exit_trade_records_nb(
    c: tp.NamedTuple,
    col: int,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get exit trade records of a column up to this point.

    Built from the order records of the column filled so far and the close price
    up to and including the current row (`c.i`). Arguments `init_position` and
    `init_price` are forwarded to the trade builder as-is."""
    order_records = get_col_order_records_nb(c, col)
    col_map = records_nb.col_map_nb(order_records["col"], c.target_shape[1])
    close = flex_select_col_nb(c.close, col)
    exit_trades = pf_records_nb.get_exit_trades_nb(
        order_records,
        close[: c.i + 1],
        col_map,
        init_position=init_position,
        init_price=init_price,
    )
    return exit_trades


@register_jitted
def get_exit_trade_records_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get exit trade records of the current column up to this point."""
    return get_col_exit_trade_records_nb(c, c.col, init_position=init_position, init_price=init_price)


@register_jitted
def get_col_position_records_nb(
    c: tp.NamedTuple,
    col: int,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get position records of a column up to this point.

    Derived from the exit trade records of the column
    (see `get_col_exit_trade_records_nb`)."""
    exit_trade_records = get_col_exit_trade_records_nb(c, col, init_position=init_position, init_price=init_price)
    col_map = records_nb.col_map_nb(exit_trade_records["col"], c.target_shape[1])
    position_records = pf_records_nb.get_positions_nb(exit_trade_records, col_map)
    return position_records


@register_jitted
def get_position_records_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
    """Get position records of the current column up to this point."""
    return get_col_position_records_nb(c, c.col, init_position=init_position, init_price=init_price)


# ############# Simulation ############# #


@register_jitted
def stop_group_sim_nb(c: tp.NamedTuple, group: int) -> None:
    """Stop the simulation of a group.

    Sets `c.sim_end[group]` in place to the row right after the current one (`c.i + 1`)."""
    c.sim_end[group] = c.i + 1


@register_jitted
def stop_sim_nb(
    c: tp.Union[
        SegmentContext,
        OrderContext,
        PostOrderContext,
        FlexOrderContext,
        SignalSegmentContext,
        SignalContext,
        PostSignalContext,
    ],
) -> None:
    """Stop the simulation of the current group."""
    stop_group_sim_nb(c, c.group)


# ############# Ordering ############# #


@register_jitted
def get_exec_state_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    val_price: tp.Optional[float] = None,
) -> ExecState:
    """Get execution state.

    Assembles the state from the last known cash, position, debt, locked cash, and free cash
    of the current column (or group with cash sharing).

    If `val_price` is provided, it overrides the last valuation price and the value is
    re-calculated for it using `update_value_nb`. Otherwise, the last valuation price
    and value are used as-is."""
    # Neither cash nor position is modified here, so read them once
    cash = get_cash_nb(c)
    position = get_position_nb(c)
    if val_price is not None:
        _val_price = float(val_price)
        # Only the valuation price changes, so "before" and "now" are the same for cash and position
        value = float(
            update_value_nb(
                cash_before=cash,
                cash_now=cash,
                position_before=position,
                position_now=position,
                val_price_before=get_val_price_nb(c),
                val_price_now=_val_price,
                value_before=get_value_nb(c),
            )
        )
    else:
        _val_price = float(get_val_price_nb(c))
        value = float(get_value_nb(c))
    return ExecState(
        cash=cash,
        position=position,
        debt=get_debt_nb(c),
        locked_cash=get_locked_cash_nb(c),
        free_cash=get_free_cash_nb(c),
        val_price=_val_price,
        value=value,
    )


@register_jitted
def get_price_area_nb(c: tp.NamedTuple) -> PriceArea:
    """Get price area.

    Selects open, high, low, and close at the current row (`c.i`) using `select_nb`."""
    return PriceArea(
        open=select_nb(c, c.open, i=c.i),
        high=select_nb(c, c.high, i=c.i),
        low=select_nb(c, c.low, i=c.i),
        close=select_nb(c, c.close, i=c.i),
    )


@register_jitted
def get_order_size_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    size: float,
    size_type: int = SizeType.Amount,
    val_price: tp.Optional[float] = None,
) -> float:
    """Get order size.

    Resolves `size` of type `size_type` using `resolve_size_nb` against the current position,
    valuation price, and value, and returns the first element of its result.

    If `val_price` is provided, it overrides the last valuation price (see `get_exec_state_nb`).

    Raises an error for the size types `SizeType.Percent` and `SizeType.Percent100`."""
    # Reject unsupported size types before building the execution state
    if size_type == SizeType.Percent100 or size_type == SizeType.Percent:
        raise ValueError("Size type Percent(100) is not supported")
    exec_state = get_exec_state_nb(c, val_price=val_price)
    return resolve_size_nb(
        size=size,
        size_type=size_type,
        position=get_position_nb(c),
        val_price=exec_state.val_price,
        value=exec_state.value,
    )[0]


@register_jitted
def get_order_value_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    size: float,
    size_type: int = SizeType.Amount,
    direction: int = Direction.Both,
    val_price: tp.Optional[float] = None,
) -> float:
    """Get (approximate) order value.

    Calculated with `approx_order_value_nb` on the current execution state.

    If `val_price` is provided, it overrides the last valuation price (see `get_exec_state_nb`)."""
    exec_state = get_exec_state_nb(c, val_price=val_price)
    return approx_order_value_nb(
        exec_state,
        size=size,
        size_type=size_type,
        direction=direction,
    )


@register_jitted
def get_order_result_nb(
    c: tp.Union[
        OrderContext,
        PostOrderContext,
        SignalContext,
        PostSignalContext,
    ],
    order: Order,
    val_price: tp.Optional[float] = None,
    update_value: bool = False,
) -> tp.Tuple[OrderResult, ExecState]:
    """Get order result and new execution state.

    Executes `order` with `execute_order_nb` on the current execution state and
    the price area of the current row.

    If `val_price` is provided, it overrides the last valuation price (see `get_exec_state_nb`).

    Doesn't have any effect on the simulation state."""
    exec_state = get_exec_state_nb(c, val_price=val_price)
    price_area = get_price_area_nb(c)
    return execute_order_nb(
        exec_state=exec_state,
        order=order,
        price_area=price_area,
        update_value=update_value,
    )


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for portfolio simulation based on an order function."""

from numba import prange

from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio import chunking as portfolio_ch
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.iter_ import *
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.returns import nb as returns_nb_
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import insert_argsort_nb
from vectorbtpro.utils.template import RepFunc


@register_jitted(cache=True)
def calc_group_value_nb(
    from_col: int,
    to_col: int,
    cash_now: float,
    last_position: tp.Array1d,
    last_val_price: tp.Array1d,
) -> float:
    """Calculate group value.

    Starts with `cash_now` and adds position times valuation price for each column
    between `from_col` (inclusive) and `to_col` (exclusive) that has a non-zero position."""
    group_value = cash_now
    for col in range(from_col, to_col):
        # Skip zero positions so that their valuation price is never involved
        if last_position[col] != 0:
            group_value += last_position[col] * last_val_price[col]
    return group_value


@register_jitted
def calc_ctx_group_value_nb(seg_ctx: SegmentContext) -> float:
    """Calculate group value from context.

    Accepts `vectorbtpro.portfolio.enums.SegmentContext`.

    Best called once from `pre_segment_func_nb`. To set the valuation price,
    change `last_val_price` of the context in-place.

    !!! note
        Cash sharing must be enabled."""
    if not seg_ctx.cash_sharing:
        raise ValueError("Cash sharing must be enabled")
    return calc_group_value_nb(
        seg_ctx.from_col,
        seg_ctx.to_col,
        seg_ctx.last_cash[seg_ctx.group],
        seg_ctx.last_position,
        seg_ctx.last_val_price,
    )


@register_jitted
def sort_call_seq_out_1d_nb(
    ctx: SegmentContext,
    size: tp.FlexArray1d,
    size_type: tp.FlexArray1d,
    direction: tp.FlexArray1d,
    order_value_out: tp.Array1d,
    call_seq_out: tp.Array1d,
) -> None:
    """Sort call sequence `call_seq_out` based on the value of each potential order.

    Accepts `vectorbtpro.portfolio.enums.SegmentContext` and other arguments, sorts `call_seq_out` in place,
    and returns nothing.

    Arrays `size`, `size_type`, and `direction` utilize flexible indexing; they must be 1-dim arrays
    that broadcast to `group_len`.

    The lengths of `order_value_out` and `call_seq_out` must match the number of columns in the group.
    Array `order_value_out` must be empty and will contain sorted order values after execution.
    Array `call_seq_out` must be filled with integers ranging from 0 to the number of columns in the group
    (in this exact order).

    Best called once from `pre_segment_func_nb`.

    !!! note
        Cash sharing must be enabled and `call_seq_out` must follow `CallSeqType.Default`."""
    if not ctx.cash_sharing:
        raise ValueError("Cash sharing must be enabled")
    group_value_now = calc_ctx_group_value_nb(ctx)
    # Cash sharing is guaranteed by the guard above, so cash is tracked per group
    # and stays the same for every column of the group
    cash_now = ctx.last_cash[ctx.group]
    free_cash_now = ctx.last_free_cash[ctx.group]
    group_len = ctx.to_col - ctx.from_col
    for c in range(group_len):
        if call_seq_out[c] != c:
            raise ValueError("call_seq_out must follow CallSeqType.Default")
        col = ctx.from_col + c
        _size = flex_select_1d_pc_nb(size, c)
        _size_type = flex_select_1d_pc_nb(size_type, c)
        _direction = flex_select_1d_pc_nb(direction, c)
        exec_state = ExecState(
            cash=cash_now,
            position=ctx.last_position[col],
            debt=ctx.last_debt[col],
            locked_cash=ctx.last_locked_cash[col],
            free_cash=free_cash_now,
            val_price=ctx.last_val_price[col],
            value=group_value_now,
        )
        order_value_out[c] = approx_order_value_nb(
            exec_state,
            _size,
            _size_type,
            _direction,
        )
    # Sort by order value
    insert_argsort_nb(order_value_out, call_seq_out)


@register_jitted
def sort_call_seq_1d_nb(
    ctx: SegmentContext,
    size: tp.FlexArray1d,
    size_type: tp.FlexArray1d,
    direction: tp.FlexArray1d,
    order_value_out: tp.Array1d,
) -> None:
    """Sort call sequence attached to `vectorbtpro.portfolio.enums.SegmentContext`.

    See `sort_call_seq_out_1d_nb`.

    !!! note
        Can only be used in non-flexible simulation functions."""
    if ctx.call_seq_now is None:
        raise ValueError("Call sequence array is None. Use sort_call_seq_out_1d_nb to sort a custom array.")
    sort_call_seq_out_1d_nb(ctx, size, size_type, direction, order_value_out, ctx.call_seq_now)


@register_jitted
def sort_call_seq_out_nb(
    ctx: SegmentContext,
    size: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    order_value_out: tp.Array1d,
    call_seq_out: tp.Array1d,
) -> None:
    """Same as `sort_call_seq_out_1d_nb` but with `size`, `size_type`, and `direction` being 2-dim arrays.

    Elements are selected per column using `select_from_col_nb`."""
    if not ctx.cash_sharing:
        raise ValueError("Cash sharing must be enabled")
    group_value_now = calc_ctx_group_value_nb(ctx)
    # Cash sharing is guaranteed by the guard above, so cash is tracked per group
    # and stays the same for every column of the group
    cash_now = ctx.last_cash[ctx.group]
    free_cash_now = ctx.last_free_cash[ctx.group]
    group_len = ctx.to_col - ctx.from_col
    for c in range(group_len):
        if call_seq_out[c] != c:
            raise ValueError("call_seq_out must follow CallSeqType.Default")
        col = ctx.from_col + c
        _size = select_from_col_nb(ctx, col, size)
        _size_type = select_from_col_nb(ctx, col, size_type)
        _direction = select_from_col_nb(ctx, col, direction)
        exec_state = ExecState(
            cash=cash_now,
            position=ctx.last_position[col],
            debt=ctx.last_debt[col],
            locked_cash=ctx.last_locked_cash[col],
            free_cash=free_cash_now,
            val_price=ctx.last_val_price[col],
            value=group_value_now,
        )
        order_value_out[c] = approx_order_value_nb(
            exec_state,
            _size,
            _size_type,
            _direction,
        )
    # Sort by order value
    insert_argsort_nb(order_value_out, call_seq_out)


@register_jitted
def sort_call_seq_nb(
    ctx: SegmentContext,
    size: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    order_value_out: tp.Array1d,
) -> None:
    """Sort call sequence attached to `vectorbtpro.portfolio.enums.SegmentContext`.

    See `sort_call_seq_out_nb`.

    !!! note
        Can only be used in non-flexible simulation functions."""
    if ctx.call_seq_now is None:
        # Point to the 2-dim counterpart, since this function takes 2-dim arrays
        raise ValueError("Call sequence array is None. Use sort_call_seq_out_nb to sort a custom array.")
    sort_call_seq_out_nb(ctx, size, size_type, direction, order_value_out, ctx.call_seq_now)


@register_jitted
def try_order_nb(ctx: OrderContext, order: Order) -> tp.Tuple[OrderResult, ExecState]:
    """Execute an order without persistence.

    Builds the execution state from the `*_now` fields of the context and the price area
    from the current row and column, then runs `execute_order_nb` and returns its result.
    The context itself is not written to by this function."""
    exec_state = ExecState(
        cash=ctx.cash_now,
        position=ctx.position_now,
        debt=ctx.debt_now,
        locked_cash=ctx.locked_cash_now,
        free_cash=ctx.free_cash_now,
        val_price=ctx.val_price_now,
        value=ctx.value_now,
    )
    price_area = PriceArea(
        open=flex_select_nb(ctx.open, ctx.i, ctx.col),
        high=flex_select_nb(ctx.high, ctx.i, ctx.col),
        low=flex_select_nb(ctx.low, ctx.i, ctx.col),
        close=flex_select_nb(ctx.close, ctx.i, ctx.col),
    )
    return execute_order_nb(exec_state=exec_state, order=order, price_area=price_area)


@register_jitted
def no_pre_func_nb(c: tp.NamedTuple, *args) -> tp.Args:
    """Placeholder preprocessing function that forwards received arguments down the stack."""
    return args


@register_jitted
def no_order_func_nb(c: OrderContext, *args) -> Order:
    """Placeholder order function that returns no order."""
    return NoOrder


@register_jitted
def no_post_func_nb(c: tp.NamedTuple, *args) -> None:
    """Placeholder postprocessing function that returns nothing."""
    return None


# Signatures of the user-defined callbacks: each takes its context followed by
# variable arguments; pre-processing callbacks return arguments, the order
# callback returns an order, and post-processing callbacks return nothing
PreSimFuncT = tp.Callable[[SimulationContext, tp.VarArg()], tp.Args]
PostSimFuncT = tp.Callable[[SimulationContext, tp.VarArg()], None]
PreGroupFuncT = tp.Callable[[GroupContext, tp.VarArg()], tp.Args]
PostGroupFuncT = tp.Callable[[GroupContext, tp.VarArg()], None]
PreSegmentFuncT = tp.Callable[[SegmentContext, tp.VarArg()], tp.Args]
PostSegmentFuncT = tp.Callable[[SegmentContext, tp.VarArg()], None]
OrderFuncT = tp.Callable[[OrderContext, tp.VarArg()], Order]
PostOrderFuncT = tp.Callable[[PostOrderContext, tp.VarArg()], None]


# %
# %
# %
# @register_jitted
# def pre_sim_func_nb(
#     c: SimulationContext,
#     *args,
# ) -> tp.Args:
#     """Custom simulation pre-processing function."""
#     return args
#
#
# %
# %
# %
# %
# %
# %
#
@register_jitted # def post_sim_func_nb( # c: SimulationContext, # *args, # ) -> None: # """Custom simulation post-processing function.""" # return None # # # % # % # % # % # % # % # @register_jitted # def pre_group_func_nb( # c: GroupContext, # *args, # ) -> tp.Args: # """Custom group pre-processing function.""" # return args # # # % # % # % # % # % # % # @register_jitted # def post_group_func_nb( # c: GroupContext, # *args, # ) -> None: # """Custom group post-processing function.""" # return None # # # % # % # % # % # % # % # @register_jitted # def pre_segment_func_nb( # c: SegmentContext, # *args, # ) -> tp.Args: # """Custom segment pre-processing function.""" # return args # # # % # % # % # % # % # % # @register_jitted # def post_segment_func_nb( # c: SegmentContext, # *args, # ) -> None: # """Custom segment post-processing function.""" # return None # # # % # % # % # % # % # % # @register_jitted # def order_func_nb( # c: OrderContext, # *args, # ) -> Order: # """Custom order function.""" # return NoOrder # # # % # % # % # % # % # % # @register_jitted # def post_order_func_nb( # c: PostOrderContext, # *args, # ) -> None: # """Custom order post-processing function.""" # return None # # # % # % # % # %
# % # import vectorbtpro as vbt # from vectorbtpro.portfolio.nb.from_order_func import * # %? import_lines # # # % # %? blocks[pre_sim_func_nb_block] # % blocks["pre_sim_func_nb"] # %? blocks[post_sim_func_nb_block] # % blocks["post_sim_func_nb"] # %? blocks[pre_group_func_nb_block] # % blocks["pre_group_func_nb"] # %? blocks[post_group_func_nb_block] # % blocks["post_group_func_nb"] # %? blocks[pre_segment_func_nb_block] # % blocks["pre_segment_func_nb"] # %? blocks[post_segment_func_nb_block] # % blocks["post_segment_func_nb"] # %? blocks[order_func_nb_block] # % blocks["order_func_nb"] # %? blocks[post_order_func_nb_block] # % blocks["post_order_func_nb"] @register_chunkable( size=ch.ArraySizer(arg_query="group_lens", axis=0), arg_take_spec=dict( target_shape=base_ch.shape_gl_slicer, group_lens=ch.ArraySlicer(axis=0), cash_sharing=None, call_seq=base_ch.array_gl_slicer, init_cash=RepFunc(portfolio_ch.get_init_cash_slicer), init_position=base_ch.flex_1d_array_gl_slicer, init_price=base_ch.flex_1d_array_gl_slicer, cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer), cash_earnings=base_ch.flex_array_gl_slicer, segment_mask=base_ch.FlexArraySlicer(axis=1), call_pre_segment=None, call_post_segment=None, pre_sim_func_nb=None, # % None pre_sim_args=ch.ArgsTaker(), post_sim_func_nb=None, # % None post_sim_args=ch.ArgsTaker(), pre_group_func_nb=None, # % None pre_group_args=ch.ArgsTaker(), post_group_func_nb=None, # % None post_group_args=ch.ArgsTaker(), pre_segment_func_nb=None, # % None pre_segment_args=ch.ArgsTaker(), post_segment_func_nb=None, # % None post_segment_args=ch.ArgsTaker(), order_func_nb=None, # % None order_args=ch.ArgsTaker(), post_order_func_nb=None, # % None post_order_args=ch.ArgsTaker(), index=None, freq=None, open=base_ch.flex_array_gl_slicer, high=base_ch.flex_array_gl_slicer, low=base_ch.flex_array_gl_slicer, close=base_ch.flex_array_gl_slicer, bm_close=base_ch.flex_array_gl_slicer, sim_start=base_ch.FlexArraySlicer(), 
sim_end=base_ch.FlexArraySlicer(), ffill_val_price=None, update_value=None, fill_pos_info=None, track_value=None, max_order_records=None, max_log_records=None, in_outputs=ch.ArgsTaker(), ), **portfolio_ch.merge_sim_outs_config, setup_id=None, # %? line.replace("None", task_id) ) @register_jitted( tags={"can_parallel"}, cache=False, # % line.replace("False", "True") task_id_or_func=None, # %? line.replace("None", task_id) ) def from_order_func_nb( # %? line.replace("from_order_func_nb", new_func_name) target_shape: tp.Shape, group_lens: tp.GroupLens, cash_sharing: bool, call_seq: tp.Optional[tp.Array2d] = None, init_cash: tp.FlexArray1dLike = 100.0, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, cash_deposits: tp.FlexArray2dLike = 0.0, cash_earnings: tp.FlexArray2dLike = 0.0, segment_mask: tp.FlexArray2dLike = True, call_pre_segment: bool = False, call_post_segment: bool = False, pre_sim_func_nb: PreSimFuncT = no_pre_func_nb, # % None pre_sim_args: tp.Args = (), post_sim_func_nb: PostSimFuncT = no_post_func_nb, # % None post_sim_args: tp.Args = (), pre_group_func_nb: PreGroupFuncT = no_pre_func_nb, # % None pre_group_args: tp.Args = (), post_group_func_nb: PostGroupFuncT = no_post_func_nb, # % None post_group_args: tp.Args = (), pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb, # % None pre_segment_args: tp.Args = (), post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None post_segment_args: tp.Args = (), order_func_nb: OrderFuncT = no_order_func_nb, # % None order_args: tp.Args = (), post_order_func_nb: PostOrderFuncT = no_post_func_nb, # % None post_order_args: tp.Args = (), index: tp.Optional[tp.Array1d] = None, freq: tp.Optional[int] = None, open: tp.FlexArray2dLike = np.nan, high: tp.FlexArray2dLike = np.nan, low: tp.FlexArray2dLike = np.nan, close: tp.FlexArray2dLike = np.nan, bm_close: tp.FlexArray2dLike = np.nan, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = 
None, ffill_val_price: bool = True, update_value: bool = False, fill_pos_info: bool = True, track_value: bool = True, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = 0, in_outputs: tp.Optional[tp.NamedTuple] = None, ) -> SimulationOutput: """Fill order and log records by iterating over a shape and calling a range of user-defined functions. Starting with initial cash `init_cash`, iterates over each group and column in `target_shape`, and for each data point, generates an order using `order_func_nb`. Tries then to fulfill that order. Upon success, updates the current state including the cash balance and the position. Returns `vectorbtpro.portfolio.enums.SimulationOutput`. As opposed to `from_order_func_rw_nb`, order processing happens in column-major order. Column-major order means processing the entire column/group with all rows before moving to the next one. See [Row- and column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order). Args: target_shape (tuple): See `vectorbtpro.portfolio.enums.SimulationContext.target_shape`. group_lens (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.group_lens`. cash_sharing (bool): See `vectorbtpro.portfolio.enums.SimulationContext.cash_sharing`. call_seq (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.call_seq`. init_cash (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_cash`. init_position (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_position`. init_price (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_price`. cash_deposits (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.cash_deposits`. cash_earnings (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.cash_earnings`. segment_mask (array_like of bool): See `vectorbtpro.portfolio.enums.SimulationContext.segment_mask`. 
call_pre_segment (bool): See `vectorbtpro.portfolio.enums.SimulationContext.call_pre_segment`. call_post_segment (bool): See `vectorbtpro.portfolio.enums.SimulationContext.call_post_segment`. pre_sim_func_nb (callable): Function called before simulation. Can be used for creation of global arrays and setting the seed. Must accept `vectorbtpro.portfolio.enums.SimulationContext` and `*pre_sim_args`. Must return a tuple of any content, which is then passed to `pre_group_func_nb` and `post_group_func_nb`. pre_sim_args (tuple): Packed arguments passed to `pre_sim_func_nb`. post_sim_func_nb (callable): Function called after simulation. Must accept `vectorbtpro.portfolio.enums.SimulationContext` and `*post_sim_args`. Must return nothing. post_sim_args (tuple): Packed arguments passed to `post_sim_func_nb`. pre_group_func_nb (callable): Function called before each group. Must accept `vectorbtpro.portfolio.enums.GroupContext`, unpacked tuple from `pre_sim_func_nb`, and `*pre_group_args`. Must return a tuple of any content, which is then passed to `pre_segment_func_nb` and `post_segment_func_nb`. pre_group_args (tuple): Packed arguments passed to `pre_group_func_nb`. post_group_func_nb (callable): Function called after each group. Must accept `vectorbtpro.portfolio.enums.GroupContext`, unpacked tuple from `pre_sim_func_nb`, and `*post_group_args`. Must return nothing. post_group_args (tuple): Packed arguments passed to `post_group_func_nb`. pre_segment_func_nb (callable): Function called before each segment. Called if `segment_mask` or `call_pre_segment` is True. Must accept `vectorbtpro.portfolio.enums.SegmentContext`, unpacked tuple from `pre_group_func_nb`, and `*pre_segment_args`. Must return a tuple of any content, which is then passed to `order_func_nb` and `post_order_func_nb`. This is the right place to change call sequence and set the valuation price. 
Group re-valuation and update of the open position stats happen right after this function, regardless of whether it has been called. !!! note To change the call sequence of a segment, access `vectorbtpro.portfolio.enums.SegmentContext.call_seq_now` and change it in-place. Make sure to not generate any new arrays as it may negatively impact performance. Assigning to `vectorbtpro.portfolio.enums.SegmentContext.call_seq_now`, as to any other context (named tuple) value, is not supported. See `vectorbtpro.portfolio.enums.SegmentContext.call_seq_now`. !!! note You can override elements of `last_val_price` to manipulate group valuation. See `vectorbtpro.portfolio.enums.SimulationContext.last_val_price`. pre_segment_args (tuple): Packed arguments passed to `pre_segment_func_nb`. post_segment_func_nb (callable): Function called after each segment. Called if `segment_mask` or `call_post_segment` is True. Addition of cash_earnings, the final group re-valuation, and the final update of the open position stats happen right before this function, regardless of whether it has been called. The passed context represents the final state of each segment, so make sure to make any changes before this function is called. Must accept `vectorbtpro.portfolio.enums.SegmentContext`, unpacked tuple from `pre_group_func_nb`, and `*post_segment_args`. Must return nothing. post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`. order_func_nb (callable): Order generation function. Used for either generating an order or skipping. Must accept `vectorbtpro.portfolio.enums.OrderContext`, unpacked tuple from `pre_segment_func_nb`, and `*order_args`. Must return `vectorbtpro.portfolio.enums.Order`. !!! note If the returned order has been rejected, there is no way of issuing a new order. You should make sure that the order passes, for example, by using `try_order_nb`. To have greater freedom in order management, use `from_flex_order_func_nb`. 
order_args (tuple): Arguments passed to `order_func_nb`. post_order_func_nb (callable): Callback that is called after the order has been processed. Used for checking the order status and doing some post-processing. Must accept `vectorbtpro.portfolio.enums.PostOrderContext`, unpacked tuple from `pre_segment_func_nb`, and `*post_order_args`. Must return nothing. post_order_args (tuple): Arguments passed to `post_order_func_nb`. index (array): See `vectorbtpro.portfolio.enums.SimulationContext.index`. freq (int): See `vectorbtpro.portfolio.enums.SimulationContext.freq`. open (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.open`. high (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.high`. low (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.low`. close (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.close`. bm_close (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.bm_close`. sim_start (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.sim_start`. sim_end (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.sim_end`. ffill_val_price (bool): See `vectorbtpro.portfolio.enums.SimulationContext.ffill_val_price`. update_value (bool): See `vectorbtpro.portfolio.enums.SimulationContext.update_value`. fill_pos_info (bool): See `vectorbtpro.portfolio.enums.SimulationContext.fill_pos_info`. track_value (bool): See `vectorbtpro.portfolio.enums.SimulationContext.track_value`. max_order_records (int): The max number of order records expected to be filled at each column. max_log_records (int): The max number of log records expected to be filled at each column. in_outputs (namedtuple): See `vectorbtpro.portfolio.enums.SimulationContext.in_outputs`. !!! note Remember that indexing of 2-dim arrays in vectorbt follows that of pandas: `a[i, col]`. !!! warning You can only safely access data of columns that are to the left of the current group and rows that are above the current row within the same group. Other data points have not been processed yet and are thus empty. 
Accessing them will not trigger any errors or warnings, but will provide you with arbitrary data (see [np.empty](https://numpy.org/doc/stable/reference/generated/numpy.empty.html)). Call hierarchy: Like most things in the vectorbt universe, simulation is also done by iterating over an (imaginary) frame. This frame consists of two dimensions: time (rows) and assets/features (columns). Each element of this frame is a potential order, which gets generated by calling an order function. The question is: how do we move across this frame to simulate trading? There are two movement patterns: column-major order (as done by `from_order_func_nb`) and row-major order (as done by `from_order_func_rw_nb`). In each of these patterns, we are always moving from top to bottom (time axis) and from left to right (asset/feature axis); the only difference between them is across which axis we are moving faster: do we want to process each column first (thus assuming that columns are independent) or each row? Choosing between them is mostly a matter of preference, but it also determines which data is available when generating an order. The frame is further divided into "blocks": columns, groups, rows, segments, and elements. For example, columns can be grouped into groups that may or may not share the same capital. Regardless of capital sharing, each collection of elements within a group and a time step is called a segment, which simply defines a single context (such as shared capital) for one or multiple orders. Each segment can also define a custom sequence (a so-called call sequence) in which orders are executed. You can imagine each of these blocks as a rectangle drawn over different parts of the frame, and having its own context and pre/post-processing function. The pre-processing function is a simple callback that is called before entering the block, and can be provided by the user to, for example, prepare arrays or do some custom calculations. 
It must return a tuple (can be empty) that is then unpacked and passed as arguments to the pre- and postprocessing function coming next in the call hierarchy. The postprocessing function can be used, for example, to write user-defined arrays such as returns. ```plaintext 1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args) 2. pre_group_out = pre_group_func_nb(GroupContext, *pre_sim_out, *pre_group_args) 3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_group_out, *pre_segment_args) 4. if segment_mask: order = order_func_nb(OrderContext, *pre_segment_out, *order_args) 5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args) ... 6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_group_out, *post_segment_args) ... 7. post_group_func_nb(GroupContext, *pre_sim_out, *post_group_args) ... 8. post_sim_func_nb(SimulationContext, *post_sim_args) ``` Let's demonstrate a frame with one group of two columns and one group of one column, and the following call sequence: ```plaintext array([[0, 1, 0], [1, 0, 0]]) ``` ![](/assets/images/api/from_order_func_nb.svg){: loading=lazy style="width:800px;" } And here is the context information available at each step: ![](/assets/images/api/context_info.svg){: loading=lazy style="width:700px;" } Usage: * Create a group of three assets together sharing 100$ and simulate an equal-weighted portfolio that rebalances every second tick, all without leaving Numba: ```pycon >>> from vectorbtpro import * >>> @njit ... def pre_sim_func_nb(c): ... print('before simulation') ... # Create a temporary array and pass it down the stack ... order_value_out = np.empty(c.target_shape[1], dtype=float_) ... return (order_value_out,) >>> @njit ... def pre_group_func_nb(c, order_value_out): ... print('\\tbefore group', c.group) ... # Forward down the stack (you can omit pre_group_func_nb entirely) ... return (order_value_out,) >>> @njit ... 
def pre_segment_func_nb(c, order_value_out, size, price, size_type, direction): ... print('\\t\\tbefore segment', c.i) ... for col in range(c.from_col, c.to_col): ... # Here we use order price for group valuation ... c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price) ... ... # Reorder call sequence of this segment such that selling orders come first and buying last ... # Rearranges c.call_seq_now based on order value (size, size_type, direction, and val_price) ... # Utilizes flexible indexing using select_from_col_nb (as we did above) ... vbt.pf_nb.sort_call_seq_nb( ... c, ... size, ... size_type, ... direction, ... order_value_out[c.from_col:c.to_col] ... ) ... # Forward nothing ... return () >>> @njit ... def order_func_nb(c, size, price, size_type, direction, fees, fixed_fees, slippage): ... print('\\t\\t\\tcreating order', c.call_idx, 'at column', c.col) ... # Create and return an order ... return vbt.pf_nb.order_nb( ... size=vbt.pf_nb.select_nb(c, size), ... price=vbt.pf_nb.select_nb(c, price), ... size_type=vbt.pf_nb.select_nb(c, size_type), ... direction=vbt.pf_nb.select_nb(c, direction), ... fees=vbt.pf_nb.select_nb(c, fees), ... fixed_fees=vbt.pf_nb.select_nb(c, fixed_fees), ... slippage=vbt.pf_nb.select_nb(c, slippage) ... ) >>> @njit ... def post_order_func_nb(c): ... print('\\t\\t\\t\\torder status:', c.order_result.status) ... return None >>> @njit ... def post_segment_func_nb(c, order_value_out): ... print('\\t\\tafter segment', c.i) ... return None >>> @njit ... def post_group_func_nb(c, order_value_out): ... print('\\tafter group', c.group) ... return None >>> @njit ... def post_sim_func_nb(c): ... print('after simulation') ... 
return None >>> target_shape = (5, 3) >>> np.random.seed(42) >>> group_lens = np.array([3]) # one group of three columns >>> cash_sharing = True >>> segment_mask = np.array([True, False, True, False, True])[:, None] >>> price = close = np.random.uniform(1, 10, size=target_shape) >>> size = np.array([[1 / target_shape[1]]]) # custom flexible arrays must be 2-dim >>> size_type = np.array([[vbt.pf_enums.SizeType.TargetPercent]]) >>> direction = np.array([[vbt.pf_enums.Direction.LongOnly]]) >>> fees = np.array([[0.001]]) >>> fixed_fees = np.array([[1.]]) >>> slippage = np.array([[0.001]]) >>> sim_out = vbt.pf_nb.from_order_func_nb( ... target_shape, ... group_lens, ... cash_sharing, ... segment_mask=segment_mask, ... pre_sim_func_nb=pre_sim_func_nb, ... post_sim_func_nb=post_sim_func_nb, ... pre_group_func_nb=pre_group_func_nb, ... post_group_func_nb=post_group_func_nb, ... pre_segment_func_nb=pre_segment_func_nb, ... pre_segment_args=(size, price, size_type, direction), ... post_segment_func_nb=post_segment_func_nb, ... order_func_nb=order_func_nb, ... order_args=(size, price, size_type, direction, fees, fixed_fees, slippage), ... post_order_func_nb=post_order_func_nb ... 
) before simulation before group 0 before segment 0 creating order 0 at column 0 order status: 0 creating order 1 at column 1 order status: 0 creating order 2 at column 2 order status: 0 after segment 0 before segment 2 creating order 0 at column 1 order status: 0 creating order 1 at column 2 order status: 0 creating order 2 at column 0 order status: 0 after segment 2 before segment 4 creating order 0 at column 0 order status: 0 creating order 1 at column 2 order status: 0 creating order 2 at column 1 order status: 0 after segment 4 after group 0 after simulation >>> pd.DataFrame.from_records(sim_out.order_records) id col idx size price fees side 0 0 0 0 7.626262 4.375232 1.033367 0 1 1 0 2 5.210115 1.524275 1.007942 0 2 2 0 4 7.899568 8.483492 1.067016 1 3 0 1 0 3.488053 9.565985 1.033367 0 4 1 1 2 0.920352 8.786790 1.008087 1 5 2 1 4 10.713236 2.913963 1.031218 0 6 0 2 0 3.972040 7.595533 1.030170 0 7 1 2 2 0.448747 6.403625 1.002874 1 8 2 2 4 12.378281 2.639061 1.032667 0 >>> col_map = vbt.rec_nb.col_map_nb(sim_out.order_records['col'], target_shape[1]) >>> asset_flow = vbt.pf_nb.asset_flow_nb(target_shape, sim_out.order_records, col_map) >>> assets = vbt.pf_nb.assets_nb(asset_flow) >>> asset_value = vbt.pf_nb.asset_value_nb(close, assets) >>> vbt.Scatter(data=asset_value).fig.show() ``` ![](/assets/images/api/from_order_func_nb_example.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_order_func_nb_example.dark.svg#only-dark){: .iimg loading=lazy } Note that the last order in a group with cash sharing is always disadvantaged as it has a bit less funds than the previous orders due to costs, which are not included when valuating the group. 
""" check_group_lens_nb(group_lens, target_shape[1]) init_cash_ = to_1d_array_nb(np.asarray(init_cash)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) segment_mask_ = to_2d_array_nb(np.asarray(segment_mask)) open_ = to_2d_array_nb(np.asarray(open)) high_ = to_2d_array_nb(np.asarray(high)) low_ = to_2d_array_nb(np.asarray(low)) close_ = to_2d_array_nb(np.asarray(close)) bm_close_ = to_2d_array_nb(np.asarray(bm_close)) order_records, log_records = prepare_records_nb( target_shape=target_shape, max_order_records=max_order_records, max_log_records=max_log_records, ) last_cash = prepare_last_cash_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, ) last_position = prepare_last_position_nb( target_shape=target_shape, init_position=init_position_, ) last_value = prepare_last_value_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, ) last_pos_info = prepare_last_pos_info_nb( target_shape, init_position=init_position_, init_price=init_price_, fill_pos_info=fill_pos_info, ) last_cash_deposits = np.full_like(last_cash, 0.0) last_val_price = np.full_like(last_position, np.nan) last_debt = np.full_like(last_position, 0.0) last_locked_cash = np.full_like(last_position, 0.0) last_free_cash = last_cash.copy() prev_close_value = last_value.copy() last_return = np.full_like(last_cash, np.nan) order_counts = np.full(target_shape[1], 0, dtype=int_) log_counts = np.full(target_shape[1], 0, dtype=int_) temp_call_seq = np.empty(target_shape[1], dtype=int_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(target_shape[0], len(group_lens)), 
sim_start=sim_start, sim_end=sim_end, ) # Call function before the simulation pre_sim_ctx = SimulationContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, ) pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] group_len = to_col - from_col # Call function before the group pre_group_ctx = GroupContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, 
last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, ) pre_group_out = pre_group_func_nb(pre_group_ctx, *pre_sim_out, *pre_group_args) _sim_start = sim_start_[group] _sim_end = sim_end_[group] for i in range(_sim_start, _sim_end): if call_seq is None: for c in range(group_len): temp_call_seq[c] = c call_seq_now = temp_call_seq[:group_len] else: call_seq_now = call_seq[i, from_col:to_col] if track_value: # Update valuation price using current open for col in range(from_col, to_col): _open = flex_select_nb(open_, i, col) if not np.isnan(_open) or not ffill_val_price: last_val_price[col] = _open # Update previous value, current value, and return if cash_sharing: last_value[group] = calc_group_value_nb( from_col, to_col, last_cash[group], last_position, last_val_price, ) last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group]) else: for col in range(from_col, to_col): if last_position[col] == 0: last_value[col] = last_cash[col] else: last_value[col] = last_cash[col] + last_position[col] * last_val_price[col] last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col]) # Update open position stats if fill_pos_info: for col in range(from_col, to_col): update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col]) # Is this segment active? 
is_segment_active = flex_select_nb(segment_mask_, i, group) if call_pre_segment or is_segment_active: # Call function before the segment pre_seg_ctx = SegmentContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, call_seq_now=call_seq_now, ) pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_group_out, *pre_segment_args) # Add cash if cash_sharing: _cash_deposits = flex_select_nb(cash_deposits_, i, group) last_cash[group] += _cash_deposits last_free_cash[group] += _cash_deposits last_cash_deposits[group] = _cash_deposits else: for col in range(from_col, to_col): _cash_deposits = flex_select_nb(cash_deposits_, i, col) last_cash[col] += _cash_deposits last_free_cash[col] += _cash_deposits last_cash_deposits[col] = _cash_deposits if track_value: # Update value and return if cash_sharing: last_value[group] = calc_group_value_nb( from_col, to_col, last_cash[group], last_position, last_val_price, ) last_return[group] = returns_nb_.get_return_nb( prev_close_value[group], last_value[group] - 
last_cash_deposits[group], ) else: for col in range(from_col, to_col): if last_position[col] == 0: last_value[col] = last_cash[col] else: last_value[col] = last_cash[col] + last_position[col] * last_val_price[col] last_return[col] = returns_nb_.get_return_nb( prev_close_value[col], last_value[col] - last_cash_deposits[col], ) # Update open position stats if fill_pos_info: for col in range(from_col, to_col): update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col]) # Is this segment active? if is_segment_active: for k in range(group_len): if cash_sharing: c = call_seq_now[k] if c >= group_len: raise ValueError("Call index out of bounds of the group") else: c = k col = from_col + c # Get current values position_now = last_position[col] debt_now = last_debt[col] locked_cash_now = last_locked_cash[col] val_price_now = last_val_price[col] pos_info_now = last_pos_info[col] if cash_sharing: cash_now = last_cash[group] free_cash_now = last_free_cash[group] value_now = last_value[group] return_now = last_return[group] cash_deposits_now = last_cash_deposits[group] else: cash_now = last_cash[col] free_cash_now = last_free_cash[col] value_now = last_value[col] return_now = last_return[col] cash_deposits_now = last_cash_deposits[col] # Generate the next order order_ctx = OrderContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, 
last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, call_seq_now=call_seq_now, col=col, call_idx=k, cash_now=cash_now, position_now=position_now, debt_now=debt_now, locked_cash_now=locked_cash_now, free_cash_now=free_cash_now, val_price_now=val_price_now, value_now=value_now, return_now=return_now, pos_info_now=pos_info_now, ) order = order_func_nb(order_ctx, *pre_segment_out, *order_args) if not track_value: if ( order.size_type == SizeType.Value or order.size_type == SizeType.TargetValue or order.size_type == SizeType.TargetPercent ): raise ValueError("Cannot use size type that depends on not tracked value") # Process the order price_area = PriceArea( open=flex_select_nb(open_, i, col), high=flex_select_nb(high_, i, col), low=flex_select_nb(low_, i, col), close=flex_select_nb(close_, i, col), ) exec_state = ExecState( cash=cash_now, position=position_now, debt=debt_now, locked_cash=locked_cash_now, free_cash=free_cash_now, val_price=val_price_now, value=value_now, ) order_result, new_exec_state = process_order_nb( group=group, col=col, i=i, exec_state=exec_state, order=order, price_area=price_area, update_value=update_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, ) # Update execution state cash_now = new_exec_state.cash position_now = new_exec_state.position debt_now = new_exec_state.debt locked_cash_now = new_exec_state.locked_cash free_cash_now = new_exec_state.free_cash if track_value: val_price_now = new_exec_state.val_price value_now = new_exec_state.value if cash_sharing: return_now = returns_nb_.get_return_nb( prev_close_value[group], value_now - cash_deposits_now, ) else: return_now = 
returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now) # Now becomes last last_position[col] = position_now last_debt[col] = debt_now last_locked_cash[col] = locked_cash_now if cash_sharing: last_cash[group] = cash_now last_free_cash[group] = free_cash_now else: last_cash[col] = cash_now last_free_cash[col] = free_cash_now if track_value: if not np.isnan(val_price_now) or not ffill_val_price: last_val_price[col] = val_price_now if cash_sharing: last_value[group] = value_now last_return[group] = return_now else: last_value[col] = value_now last_return[col] = return_now # Update position record if fill_pos_info: if order_result.status == OrderStatus.Filled: if order_counts[col] > 0: order_id = order_records["id"][order_counts[col] - 1, col] else: order_id = -1 update_pos_info_nb( pos_info_now, i, col, exec_state.position, position_now, order_result, order_id, ) # Post-order callback post_order_ctx = PostOrderContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, 
call_seq_now=call_seq_now, col=col, call_idx=k, cash_before=exec_state.cash, position_before=exec_state.position, debt_before=exec_state.debt, locked_cash_before=exec_state.locked_cash, free_cash_before=exec_state.free_cash, val_price_before=exec_state.val_price, value_before=exec_state.value, order_result=order_result, cash_now=cash_now, position_now=position_now, debt_now=debt_now, locked_cash_now=locked_cash_now, free_cash_now=free_cash_now, val_price_now=val_price_now, value_now=value_now, return_now=return_now, pos_info_now=pos_info_now, ) post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args) # NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows # Add earnings in cash for col in range(from_col, to_col): _cash_earnings = flex_select_nb(cash_earnings_, i, col) if cash_sharing: last_cash[group] += _cash_earnings last_free_cash[group] += _cash_earnings else: last_cash[col] += _cash_earnings last_free_cash[col] += _cash_earnings if track_value: # Update valuation price using current close for col in range(from_col, to_col): _close = flex_select_nb(close_, i, col) if not np.isnan(_close) or not ffill_val_price: last_val_price[col] = _close # Update previous value, current value, and return if cash_sharing: last_value[group] = calc_group_value_nb( from_col, to_col, last_cash[group], last_position, last_val_price, ) last_return[group] = returns_nb_.get_return_nb( prev_close_value[group], last_value[group] - last_cash_deposits[group], ) prev_close_value[group] = last_value[group] else: for col in range(from_col, to_col): if last_position[col] == 0: last_value[col] = last_cash[col] else: last_value[col] = last_cash[col] + last_position[col] * last_val_price[col] last_return[col] = returns_nb_.get_return_nb( prev_close_value[col], last_value[col] - last_cash_deposits[col], ) prev_close_value[col] = last_value[col] # Update open position stats if fill_pos_info: for col in range(from_col, to_col): 
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col]) # Is this segment active? if call_post_segment or is_segment_active: # Call function before the segment post_seg_ctx = SegmentContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, call_seq_now=call_seq_now, ) post_segment_func_nb(post_seg_ctx, *pre_group_out, *post_segment_args) if i >= sim_end_[group] - 1: break # Call function after the group post_group_ctx = GroupContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, 
order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, ) post_group_func_nb(post_group_ctx, *pre_sim_out, *post_group_args) # Call function after the simulation post_sim_ctx = SimulationContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, call_seq=call_seq, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, segment_mask=segment_mask_, call_pre_segment=call_pre_segment, call_post_segment=call_post_segment, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, bm_close=bm_close_, ffill_val_price=ffill_val_price, update_value=update_value, fill_pos_info=fill_pos_info, track_value=track_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, sim_start=sim_start_, sim_end=sim_end_, ) post_sim_func_nb(post_sim_ctx, *post_sim_args) sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start_, sim_end=sim_end_, allow_none=True, ) return prepare_sim_out_nb( order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, cash_deposits=cash_deposits_, cash_earnings=cash_earnings_, 
call_seq=call_seq, in_outputs=in_outputs, sim_start=sim_start_out, sim_end=sim_end_out, ) # %
# Signature of a row pre-processing callback: receives a `RowContext` followed by
# variable positional arguments and returns a tuple of arguments.
PreRowFuncT = tp.Callable[[RowContext, tp.VarArg()], tp.Args]
# Signature of a row post-processing callback: receives a `RowContext` followed by
# variable positional arguments and returns nothing.
PostRowFuncT = tp.Callable[[RowContext, tp.VarArg()], None]

# NOTE: The commented lines below are template markers and template stubs; keep them as they are.
# %
# %
# %
# @register_jitted
# def pre_row_func_nb(
#     c: RowContext,
#     *args,
# ) -> tp.Args:
#     """Custom row pre-processing function."""
#     return args
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_row_func_nb(
#     c: RowContext,
#     *args,
# ) -> None:
#     """Custom row post-processing function."""
#     return None
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_row_func_nb_block]
# % blocks["pre_row_func_nb"]
# %? blocks[post_row_func_nb_block]
# % blocks["post_row_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[order_func_nb_block]
# % blocks["order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        target_shape=base_ch.shape_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
        cash_sharing=None,
        call_seq=base_ch.array_gl_slicer,
        init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
        init_position=base_ch.flex_1d_array_gl_slicer,
        init_price=base_ch.flex_1d_array_gl_slicer,
        cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
        cash_earnings=base_ch.flex_array_gl_slicer,
        segment_mask=base_ch.FlexArraySlicer(axis=1),
        call_pre_segment=None,
        call_post_segment=None,
        pre_sim_func_nb=None,  # % None
        pre_sim_args=ch.ArgsTaker(),
        post_sim_func_nb=None,  # % None
        post_sim_args=ch.ArgsTaker(),
        pre_row_func_nb=None,  # % None
        pre_row_args=ch.ArgsTaker(),
        post_row_func_nb=None,  # % None
        post_row_args=ch.ArgsTaker(),
        pre_segment_func_nb=None,  # % None
        pre_segment_args=ch.ArgsTaker(),
        post_segment_func_nb=None,  # % None
        post_segment_args=ch.ArgsTaker(),
        order_func_nb=None,  # % None
        order_args=ch.ArgsTaker(),
        post_order_func_nb=None,  # % None
        post_order_args=ch.ArgsTaker(),
        index=None,
        freq=None,
        open=base_ch.flex_array_gl_slicer,
        high=base_ch.flex_array_gl_slicer,
        low=base_ch.flex_array_gl_slicer,
        close=base_ch.flex_array_gl_slicer,
        bm_close=base_ch.flex_array_gl_slicer,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
        ffill_val_price=None,
        update_value=None,
        fill_pos_info=None,
        track_value=None,
        max_order_records=None,
        max_log_records=None,
        in_outputs=ch.ArgsTaker(),
    ),
    **portfolio_ch.merge_sim_outs_config,
    setup_id=None,  # %? line.replace("None", task_id)
)
@register_jitted(
    tags={"can_parallel"},
    cache=False,  # % line.replace("False", "True")
    task_id_or_func=None,  # %? line.replace("None", task_id)
)
def from_order_func_rw_nb(  # %? line.replace("from_order_func_rw_nb", new_func_name)
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    call_seq: tp.Optional[tp.Array2d] = None,
    init_cash: tp.FlexArray1dLike = 100.0,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
    cash_deposits: tp.FlexArray2dLike = 0.0,
    cash_earnings: tp.FlexArray2dLike = 0.0,
    segment_mask: tp.FlexArray2dLike = True,
    call_pre_segment: bool = False,
    call_post_segment: bool = False,
    pre_sim_func_nb: PreSimFuncT = no_pre_func_nb,  # % None
    pre_sim_args: tp.Args = (),
    post_sim_func_nb: PostSimFuncT = no_post_func_nb,  # % None
    post_sim_args: tp.Args = (),
    pre_row_func_nb: PreRowFuncT = no_pre_func_nb,  # % None
    pre_row_args: tp.Args = (),
    post_row_func_nb: PostRowFuncT = no_post_func_nb,  # % None
    post_row_args: tp.Args = (),
    pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb,  # % None
    pre_segment_args: tp.Args = (),
    post_segment_func_nb: PostSegmentFuncT = no_post_func_nb,  # % None
    post_segment_args: tp.Args = (),
    order_func_nb: OrderFuncT = no_order_func_nb,  # % None
    order_args: tp.Args = (),
    post_order_func_nb: PostOrderFuncT = no_post_func_nb,  # % None
    post_order_args: tp.Args = (),
    index: tp.Optional[tp.Array1d] = None,
    freq: tp.Optional[int] = None,
    open: tp.FlexArray2dLike = np.nan,
    high: tp.FlexArray2dLike = np.nan,
    low: tp.FlexArray2dLike = np.nan,
    close: tp.FlexArray2dLike = np.nan,
    bm_close: tp.FlexArray2dLike = np.nan,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
    ffill_val_price: bool = True,
    update_value: bool = False,
    fill_pos_info: bool = True,
    track_value: bool = True,
    max_order_records: tp.Optional[int] = None,
    max_log_records: tp.Optional[int] = 0,
    in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
    """Same as `from_order_func_nb`, but iterates in row-major order.

    Row-major order means processing the entire row with all groups/columns before moving to the next one.

    The main difference is that instead of `pre_group_func_nb` it now exposes `pre_row_func_nb`,
    which is executed per entire row. It must accept `vectorbtpro.portfolio.enums.RowContext`.

    !!! note
        Function `pre_row_func_nb` is only called if there is at least one active segment in the row.
        Functions `pre_segment_func_nb` and `order_func_nb` are only called if their segment is active.
        If the main task of `pre_row_func_nb` is to activate/deactivate segments, all segments
        must be activated by default to allow `pre_row_func_nb` to be called.

    !!! warning
        You can only safely access data points that are to the left of the current group and
        rows that are to the top of the current row.

    Call hierarchy:

    ```plaintext
    1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
        2. pre_row_out = pre_row_func_nb(RowContext, *pre_sim_out, *pre_row_args)
            3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_row_out, *pre_segment_args)
                4. if segment_mask: order = order_func_nb(OrderContext, *pre_segment_out, *order_args)
                5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
                ...
            6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_row_out, *post_segment_args)
            ...
        7. post_row_func_nb(RowContext, *pre_sim_out, *post_row_args)
        ...
    8. post_sim_func_nb(SimulationContext, *post_sim_args)
    ```

    Let's illustrate the same example as in `from_order_func_nb` but adapted for this function:

    ![](/assets/images/api/from_order_func_rw_nb.svg){: loading=lazy style="width:800px;" }

    Usage:
        * Running the same example as in `from_order_func_nb` but adapted for this function:

        ```pycon
        >>> @njit
        ... def pre_row_func_nb(c, order_value_out):
        ...     print('\\tbefore row', c.i)
        ...     # Forward down the stack
        ...     return (order_value_out,)

        >>> @njit
        ... def post_row_func_nb(c, order_value_out):
        ...     print('\\tafter row', c.i)
        ...     return None

        >>> sim_out = vbt.pf_nb.from_order_func_rw_nb(
        ...     target_shape,
        ...     group_lens,
        ...     cash_sharing,
        ...     segment_mask=segment_mask,
        ...     pre_sim_func_nb=pre_sim_func_nb,
        ...     post_sim_func_nb=post_sim_func_nb,
        ...     pre_row_func_nb=pre_row_func_nb,
        ...     post_row_func_nb=post_row_func_nb,
        ...     pre_segment_func_nb=pre_segment_func_nb,
        ...     pre_segment_args=(size, price, size_type, direction),
        ...     post_segment_func_nb=post_segment_func_nb,
        ...     order_func_nb=order_func_nb,
        ...     order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
        ...     post_order_func_nb=post_order_func_nb
        ... )
        before simulation
            before row 0
                before segment 0
                    creating order 0 at column 0
                        order status: 0
                    creating order 1 at column 1
                        order status: 0
                    creating order 2 at column 2
                        order status: 0
                after segment 0
            after row 0
            before row 1
            after row 1
            before row 2
                before segment 2
                    creating order 0 at column 1
                        order status: 0
                    creating order 1 at column 2
                        order status: 0
                    creating order 2 at column 0
                        order status: 0
                after segment 2
            after row 2
            before row 3
            after row 3
            before row 4
                before segment 4
                    creating order 0 at column 0
                        order status: 0
                    creating order 1 at column 2
                        order status: 0
                    creating order 2 at column 1
                        order status: 0
                after segment 4
            after row 4
        after simulation
        ```
    """
    check_group_lens_nb(group_lens, target_shape[1])

    # Convert the array-like inputs into 1-dim/2-dim arrays; the trailing underscore
    # marks the converted version that is used from here on
    init_cash_ = to_1d_array_nb(np.asarray(init_cash))
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    init_price_ = to_1d_array_nb(np.asarray(init_price))
    cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
    cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
    segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
    open_ = to_2d_array_nb(np.asarray(open))
    high_ = to_2d_array_nb(np.asarray(high))
    low_ = to_2d_array_nb(np.asarray(low))
    close_ = to_2d_array_nb(np.asarray(close))
    bm_close_ = to_2d_array_nb(np.asarray(bm_close))

    order_records, log_records = prepare_records_nb(
        target_shape=target_shape,
        max_order_records=max_order_records,
        max_log_records=max_log_records,
    )
    last_cash = prepare_last_cash_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
    )
    last_position = prepare_last_position_nb(
        target_shape=target_shape,
        init_position=init_position_,
    )
    last_value = prepare_last_value_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
    )
    last_pos_info = prepare_last_pos_info_nb(
        target_shape,
        init_position=init_position_,
        init_price=init_price_,
        fill_pos_info=fill_pos_info,
    )
    # State arrays shaped like `last_cash` are indexed by group under cash sharing and
    # by column otherwise (see the indexing below); arrays shaped like `last_position`
    # are always indexed by column
    last_cash_deposits = np.full_like(last_cash, 0.0)
    last_val_price = np.full_like(last_position, np.nan)
    last_debt = np.full_like(last_position, 0.0)
    last_locked_cash = np.full_like(last_position, 0.0)
    last_free_cash = last_cash.copy()
    prev_close_value = last_value.copy()
    last_return = np.full_like(last_cash, np.nan)
    order_counts = np.full(target_shape[1], 0, dtype=int_)
    log_counts = np.full(target_shape[1], 0, dtype=int_)
    # Scratch buffer for building the default call sequence when `call_seq` is None
    temp_call_seq = np.empty(target_shape[1], dtype=int_)

    # Column range of each group: [group_start_idxs[g], group_end_idxs[g])
    group_end_idxs = np.cumsum(group_lens)
    group_start_idxs = group_end_idxs - group_lens

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=(target_shape[0], len(group_lens)),
        sim_start=sim_start,
        sim_end=sim_end,
    )

    # Call function before the simulation
    pre_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=call_seq,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)

    # Iterate rows from the earliest start to the latest end across all groups;
    # groups whose own range does not include the row are skipped inside the loop
    _sim_start = sim_start_.min()
    _sim_end = sim_end_.max()
    for i in range(_sim_start, _sim_end):
        # Call function before the row
        # NOTE(review): this call is unconditional for every iterated row, whereas the docstring
        # note says it is only made when the row has an active segment - confirm which is intended
        pre_row_ctx = RowContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=call_seq,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            i=i,
        )
        pre_row_out = pre_row_func_nb(pre_row_ctx, *pre_sim_out, *pre_row_args)

        for group in range(len(group_lens)):
            # Skip groups whose simulation range does not cover this row
            if i < sim_start_[group] or i >= sim_end_[group]:
                continue
            from_col = group_start_idxs[group]
            to_col = group_end_idxs[group]
            group_len = to_col - from_col
            if call_seq is None:
                # Default call sequence: 0, 1, ..., group_len - 1
                for c in range(group_len):
                    temp_call_seq[c] = c
                call_seq_now = temp_call_seq[:group_len]
            else:
                call_seq_now = call_seq[i, from_col:to_col]

            if track_value:
                # Update valuation price using current open
                for col in range(from_col, to_col):
                    _open = flex_select_nb(open_, i, col)
                    # A NaN price only overwrites the last valuation price if forward-filling is disabled
                    if not np.isnan(_open) or not ffill_val_price:
                        last_val_price[col] = _open

                # Update previous value, current value, and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            is_segment_active = flex_select_nb(segment_mask_, i, group)
            if call_pre_segment or is_segment_active:
                # Call function before the segment
                pre_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=call_seq,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=call_seq_now,
                )
                pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_row_out, *pre_segment_args)

            # Add cash
            # Deposits are selected per group under cash sharing and per column otherwise;
            # the amount is remembered so that it can be subtracted from the value in return calculations
            if cash_sharing:
                _cash_deposits = flex_select_nb(cash_deposits_, i, group)
                last_cash[group] += _cash_deposits
                last_free_cash[group] += _cash_deposits
                last_cash_deposits[group] = _cash_deposits
            else:
                for col in range(from_col, to_col):
                    _cash_deposits = flex_select_nb(cash_deposits_, i, col)
                    last_cash[col] += _cash_deposits
                    last_free_cash[col] += _cash_deposits
                    last_cash_deposits[col] = _cash_deposits

            if track_value:
                # Update value and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            if is_segment_active:
                for k in range(group_len):
                    # The call sequence is only consulted under cash sharing;
                    # otherwise columns are processed in their natural order
                    if cash_sharing:
                        c = call_seq_now[k]
                        if c >= group_len:
                            raise ValueError("Call index out of bounds of the group")
                    else:
                        c = k
                    col = from_col + c

                    # Get current values
                    position_now = last_position[col]
                    debt_now = last_debt[col]
                    locked_cash_now = last_locked_cash[col]
                    val_price_now = last_val_price[col]
                    pos_info_now = last_pos_info[col]
                    if cash_sharing:
                        cash_now = last_cash[group]
                        free_cash_now = last_free_cash[group]
                        value_now = last_value[group]
                        return_now = last_return[group]
                        cash_deposits_now = last_cash_deposits[group]
                    else:
                        cash_now = last_cash[col]
                        free_cash_now = last_free_cash[col]
                        value_now = last_value[col]
                        return_now = last_return[col]
                        cash_deposits_now = last_cash_deposits[col]

                    # Generate the next order
                    order_ctx = OrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=call_seq,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=call_seq_now,
                        col=col,
                        call_idx=k,
                        cash_now=cash_now,
                        position_now=position_now,
                        debt_now=debt_now,
                        locked_cash_now=locked_cash_now,
                        free_cash_now=free_cash_now,
                        val_price_now=val_price_now,
                        value_now=value_now,
                        return_now=return_now,
                        pos_info_now=pos_info_now,
                    )
                    order = order_func_nb(order_ctx, *pre_segment_out, *order_args)

                    # Value-based size types are rejected when the value is not being tracked
                    if not track_value:
                        if (
                            order.size_type == SizeType.Value
                            or order.size_type == SizeType.TargetValue
                            or order.size_type == SizeType.TargetPercent
                        ):
                            raise ValueError("Cannot use size type that depends on not tracked value")

                    # Process the order
                    price_area = PriceArea(
                        open=flex_select_nb(open_, i, col),
                        high=flex_select_nb(high_, i, col),
                        low=flex_select_nb(low_, i, col),
                        close=flex_select_nb(close_, i, col),
                    )
                    exec_state = ExecState(
                        cash=cash_now,
                        position=position_now,
                        debt=debt_now,
                        locked_cash=locked_cash_now,
                        free_cash=free_cash_now,
                        val_price=val_price_now,
                        value=value_now,
                    )
                    order_result, new_exec_state = process_order_nb(
                        group=group,
                        col=col,
                        i=i,
                        exec_state=exec_state,
                        order=order,
                        price_area=price_area,
                        update_value=update_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                    )

                    # Update execution state
                    cash_now = new_exec_state.cash
                    position_now = new_exec_state.position
                    debt_now = new_exec_state.debt
                    locked_cash_now = new_exec_state.locked_cash
                    free_cash_now = new_exec_state.free_cash
                    if track_value:
                        val_price_now = new_exec_state.val_price
                        value_now = new_exec_state.value
                        if cash_sharing:
                            return_now = returns_nb_.get_return_nb(
                                prev_close_value[group],
                                value_now - cash_deposits_now,
                            )
                        else:
                            return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)

                    # Now becomes last
                    last_position[col] = position_now
                    last_debt[col] = debt_now
                    last_locked_cash[col] = locked_cash_now
                    if cash_sharing:
                        last_cash[group] = cash_now
                        last_free_cash[group] = free_cash_now
                    else:
                        last_cash[col] = cash_now
                        last_free_cash[col] = free_cash_now
                    if track_value:
                        if not np.isnan(val_price_now) or not ffill_val_price:
                            last_val_price[col] = val_price_now
                        if cash_sharing:
                            last_value[group] = value_now
                            last_return[group] = return_now
                        else:
                            last_value[col] = value_now
                            last_return[col] = return_now

                    # Update position record
                    if fill_pos_info:
                        if order_result.status == OrderStatus.Filled:
                            # Id of the most recently recorded order of this column, or -1 if none is recorded
                            if order_counts[col] > 0:
                                order_id = order_records["id"][order_counts[col] - 1, col]
                            else:
                                order_id = -1
                            update_pos_info_nb(
                                pos_info_now,
                                i,
                                col,
                                exec_state.position,
                                position_now,
                                order_result,
                                order_id,
                            )

                    # Post-order callback
                    post_order_ctx = PostOrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=call_seq,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=call_seq_now,
                        col=col,
                        call_idx=k,
                        cash_before=exec_state.cash,
                        position_before=exec_state.position,
                        debt_before=exec_state.debt,
                        locked_cash_before=exec_state.locked_cash,
                        free_cash_before=exec_state.free_cash,
                        val_price_before=exec_state.val_price,
                        value_before=exec_state.value,
                        order_result=order_result,
                        cash_now=cash_now,
                        position_now=position_now,
                        debt_now=debt_now,
                        locked_cash_now=locked_cash_now,
                        free_cash_now=free_cash_now,
                        val_price_now=val_price_now,
                        value_now=value_now,
                        return_now=return_now,
                        pos_info_now=pos_info_now,
                    )
                    post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)

            # NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
            # Add earnings in cash
            for col in range(from_col, to_col):
                _cash_earnings = flex_select_nb(cash_earnings_, i, col)
                if cash_sharing:
                    last_cash[group] += _cash_earnings
                    last_free_cash[group] += _cash_earnings
                else:
                    last_cash[col] += _cash_earnings
                    last_free_cash[col] += _cash_earnings

            if track_value:
                # Update valuation price using current close
                for col in range(from_col, to_col):
                    _close = flex_select_nb(close_, i, col)
                    if not np.isnan(_close) or not ffill_val_price:
                        last_val_price[col] = _close

                # Update previous value, current value, and return
                # The value at close becomes the reference (`prev_close_value`) for later return calculations
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                    prev_close_value[group] = last_value[group]
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )
                        prev_close_value[col] = last_value[col]

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            if call_post_segment or is_segment_active:
                # Call function after the segment
                post_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=call_seq,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=call_seq_now,
                )
                post_segment_func_nb(post_seg_ctx, *pre_row_out, *post_segment_args)

        # Call function after the row
        post_row_ctx = RowContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=call_seq,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            i=i,
        )
        post_row_func_nb(post_row_ctx, *pre_sim_out, *post_row_args)

        # Stop once the current row is the last row (or beyond) of every group's range.
        # `sim_end_` is re-read on each row; presumably callbacks can shorten it through
        # the context to stop the simulation early - TODO confirm
        sim_end_reached = True
        for group in range(len(group_lens)):
            if i < sim_end_[group] - 1:
                sim_end_reached = False
                break
        if sim_end_reached:
            break

    # Call function after the simulation
    post_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=call_seq,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    post_sim_func_nb(post_sim_ctx, *post_sim_args)

    sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        sim_start=sim_start_,
        sim_end=sim_end_,
        allow_none=True,
    )
    return prepare_sim_out_nb(
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        call_seq=call_seq,
        in_outputs=in_outputs,
        sim_start=sim_start_out,
        sim_end=sim_end_out,
    )


# %
@register_jitted
def no_flex_order_func_nb(c: FlexOrderContext, *args) -> tp.Tuple[int, Order]:
    """Default flexible order function that never issues an order.

    Returns the column -1 together with `NoOrder`. `from_flex_order_func_nb` stops
    requesting orders for the current segment as soon as the returned column is -1."""
    break_col = -1
    return break_col, NoOrder


FlexOrderFuncT = tp.Callable[[FlexOrderContext, tp.VarArg()], tp.Tuple[int, Order]]

# %
# %
# %
# @register_jitted
# def flex_order_func_nb(
#     c: FlexOrderContext,
#     *args,
# ) -> tp.Tuple[int, Order]:
#     """Custom flexible order function."""
#     return -1, NoOrder
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_group_func_nb_block]
# % blocks["pre_group_func_nb"]
# %? blocks[post_group_func_nb_block]
# % blocks["post_group_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[flex_order_func_nb_block]
# % blocks["flex_order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        target_shape=base_ch.shape_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
        cash_sharing=None,
        init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
        init_position=base_ch.flex_1d_array_gl_slicer,
        init_price=base_ch.flex_1d_array_gl_slicer,
        cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
        cash_earnings=base_ch.flex_array_gl_slicer,
        segment_mask=base_ch.FlexArraySlicer(axis=1),
        call_pre_segment=None,
        call_post_segment=None,
        pre_sim_func_nb=None,  # % None
        pre_sim_args=ch.ArgsTaker(),
        post_sim_func_nb=None,  # % None
        post_sim_args=ch.ArgsTaker(),
        pre_group_func_nb=None,  # % None
        pre_group_args=ch.ArgsTaker(),
        post_group_func_nb=None,  # % None
        post_group_args=ch.ArgsTaker(),
        pre_segment_func_nb=None,  # % None
        pre_segment_args=ch.ArgsTaker(),
        post_segment_func_nb=None,  # % None
        post_segment_args=ch.ArgsTaker(),
        flex_order_func_nb=None,  # % None
        flex_order_args=ch.ArgsTaker(),
        post_order_func_nb=None,  # % None
        post_order_args=ch.ArgsTaker(),
        index=None,
        freq=None,
        open=base_ch.flex_array_gl_slicer,
        high=base_ch.flex_array_gl_slicer,
        low=base_ch.flex_array_gl_slicer,
        close=base_ch.flex_array_gl_slicer,
        bm_close=base_ch.flex_array_gl_slicer,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
        ffill_val_price=None,
        update_value=None,
        fill_pos_info=None,
        track_value=None,
        max_order_records=None,
        max_log_records=None,
        in_outputs=ch.ArgsTaker(),
    ),
    **portfolio_ch.merge_sim_outs_config,
    setup_id=None,  # %? line.replace("None", task_id)
)
@register_jitted(
    tags={"can_parallel"},
    cache=False,  # % line.replace("False", "True")
    task_id_or_func=None,  # %? line.replace("None", task_id)
)
def from_flex_order_func_nb(  # %? line.replace("from_flex_order_func_nb", new_func_name)
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    init_cash: tp.FlexArray1dLike = 100.0,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
    cash_deposits: tp.FlexArray2dLike = 0.0,
    cash_earnings: tp.FlexArray2dLike = 0.0,
    segment_mask: tp.FlexArray2dLike = True,
    call_pre_segment: bool = False,
    call_post_segment: bool = False,
    pre_sim_func_nb: PreSimFuncT = no_pre_func_nb,  # % None
    pre_sim_args: tp.Args = (),
    post_sim_func_nb: PostSimFuncT = no_post_func_nb,  # % None
    post_sim_args: tp.Args = (),
    pre_group_func_nb: PreGroupFuncT = no_pre_func_nb,  # % None
    pre_group_args: tp.Args = (),
    post_group_func_nb: PostGroupFuncT = no_post_func_nb,  # % None
    post_group_args: tp.Args = (),
    pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb,  # % None
    pre_segment_args: tp.Args = (),
    post_segment_func_nb: PostSegmentFuncT = no_post_func_nb,  # % None
    post_segment_args: tp.Args = (),
    flex_order_func_nb: FlexOrderFuncT = no_flex_order_func_nb,  # % None
    flex_order_args: tp.Args = (),
    post_order_func_nb: PostOrderFuncT = no_post_func_nb,  # % None
    post_order_args: tp.Args = (),
    index: tp.Optional[tp.Array1d] = None,
    freq: tp.Optional[int] = None,
    open: tp.FlexArray2dLike = np.nan,
    high: tp.FlexArray2dLike = np.nan,
    low: tp.FlexArray2dLike = np.nan,
    close: tp.FlexArray2dLike = np.nan,
    bm_close: tp.FlexArray2dLike = np.nan,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
    ffill_val_price: bool = True,
    update_value: bool = False,
    fill_pos_info: bool = True,
    track_value: bool = True,
    max_order_records: tp.Optional[int] = None,
    max_log_records: tp.Optional[int] = 0,
    in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
    """Same as `from_order_func_nb`, but with no predefined call sequence.

    In contrast to `order_func_nb` in `from_order_func_nb`, `flex_order_func_nb` is a segment-level
    order function that returns a column along with the order, and gets repeatedly called until
    some condition is met. This allows multiple orders to be issued within a single element and
    in an arbitrary order.

    The order function must accept `vectorbtpro.portfolio.enums.FlexOrderContext`, unpacked tuple
    from `pre_segment_func_nb`, and `*flex_order_args`. Must return column and
    `vectorbtpro.portfolio.enums.Order`. To break out of the loop, return column of -1.

    The returned column must lie within the current group (`from_col <= col < to_col`), otherwise
    a `ValueError` is raised. If `track_value` is False, orders with size type `Value`,
    `TargetValue`, or `TargetPercent` raise a `ValueError` as well.

    !!! note
        Since one element can now accommodate multiple orders, you may run into
        "order_records index out of range" exception. In this case, you must increase
        `max_order_records`. This cannot be done automatically and dynamically
        to avoid performance degradation.

    Call hierarchy:
        ```plaintext
        1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
            2. pre_group_out = pre_group_func_nb(GroupContext, *pre_sim_out, *pre_group_args)
                3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_group_out, *pre_segment_args)
                    while col != -1:
                        4. if segment_mask: col, order = flex_order_func_nb(FlexOrderContext, *pre_segment_out, *flex_order_args)
                        5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
                        ...
                6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_group_out, *post_segment_args)
                ...
            7. post_group_func_nb(GroupContext, *pre_sim_out, *post_group_args)
            ...
        8. post_sim_func_nb(SimulationContext, *post_sim_args)
        ```

    Let's illustrate the same example as in `from_order_func_nb` but adapted for this function:

    ![](/assets/images/api/from_flex_order_func_nb.svg){: loading=lazy style="width:800px;" }

    Usage:
        * The same example as in `from_order_func_nb`:

        ```pycon
        >>> from vectorbtpro import *

        >>> @njit
        ... def pre_sim_func_nb(c):
        ...     # Create temporary arrays and pass them down the stack
        ...     print('before simulation')
        ...     order_value_out = np.empty(c.target_shape[1], dtype=float_)
        ...     call_seq_out = np.empty(c.target_shape[1], dtype=int_)
        ...     return (order_value_out, call_seq_out)

        >>> @njit
        ... def pre_group_func_nb(c, order_value_out, call_seq_out):
        ...     print('\\tbefore group', c.group)
        ...     return (order_value_out, call_seq_out)

        >>> @njit
        ... def pre_segment_func_nb(c, order_value_out, call_seq_out, size, price, size_type, direction):
        ...     print('\\t\\tbefore segment', c.i)
        ...     for col in range(c.from_col, c.to_col):
        ...         # Here we use order price for group valuation
        ...         c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price)
        ...
        ...     # Same as for from_order_func_nb, but since we don't have a predefined c.call_seq_now anymore,
        ...     # we need to store our new call sequence somewhere else
        ...     call_seq_out[:] = np.arange(c.group_len)
        ...     vbt.pf_nb.sort_call_seq_out_nb(
        ...         c,
        ...         size,
        ...         size_type,
        ...         direction,
        ...         order_value_out[c.from_col:c.to_col],
        ...         call_seq_out[c.from_col:c.to_col]
        ...     )
        ...
        ...     # Forward the sorted call sequence
        ...     return (call_seq_out,)

        >>> @njit
        ... def flex_order_func_nb(c, call_seq_out, size, price, size_type, direction, fees, fixed_fees, slippage):
        ...     if c.call_idx < c.group_len:
        ...         col = c.from_col + call_seq_out[c.call_idx]
        ...         print('\\t\\t\\tcreating order', c.call_idx, 'at column', col)
        ...         # Create and return an order
        ...         return col, vbt.pf_nb.order_nb(
        ...             size=vbt.pf_nb.select_from_col_nb(c, col, size),
        ...             price=vbt.pf_nb.select_from_col_nb(c, col, price),
        ...             size_type=vbt.pf_nb.select_from_col_nb(c, col, size_type),
        ...             direction=vbt.pf_nb.select_from_col_nb(c, col, direction),
        ...             fees=vbt.pf_nb.select_from_col_nb(c, col, fees),
        ...             fixed_fees=vbt.pf_nb.select_from_col_nb(c, col, fixed_fees),
        ...             slippage=vbt.pf_nb.select_from_col_nb(c, col, slippage)
        ...         )
        ...     # All columns already processed -> break the loop
        ...     print('\\t\\t\\tbreaking out of the loop')
        ...     return -1, vbt.pf_nb.order_nothing_nb()

        >>> @njit
        ... def post_order_func_nb(c, call_seq_out):
        ...     print('\\t\\t\\t\\torder status:', c.order_result.status)
        ...     return None

        >>> @njit
        ... def post_segment_func_nb(c, order_value_out, call_seq_out):
        ...     print('\\t\\tafter segment', c.i)
        ...     return None

        >>> @njit
        ... def post_group_func_nb(c, order_value_out, call_seq_out):
        ...     print('\\tafter group', c.group)
        ...     return None

        >>> @njit
        ... def post_sim_func_nb(c):
        ...     print('after simulation')
        ...     return None

        >>> target_shape = (5, 3)
        >>> np.random.seed(42)
        >>> group_lens = np.array([3])  # one group of three columns
        >>> cash_sharing = True
        >>> segment_mask = np.array([True, False, True, False, True])[:, None]
        >>> price = close = np.random.uniform(1, 10, size=target_shape)
        >>> size = np.array([[1 / target_shape[1]]])  # custom flexible arrays must be 2-dim
        >>> size_type = np.array([[vbt.pf_enums.SizeType.TargetPercent]])
        >>> direction = np.array([[vbt.pf_enums.Direction.LongOnly]])
        >>> fees = np.array([[0.001]])
        >>> fixed_fees = np.array([[1.]])
        >>> slippage = np.array([[0.001]])

        >>> sim_out = vbt.pf_nb.from_flex_order_func_nb(
        ...     target_shape,
        ...     group_lens,
        ...     cash_sharing,
        ...     segment_mask=segment_mask,
        ...     pre_sim_func_nb=pre_sim_func_nb,
        ...     post_sim_func_nb=post_sim_func_nb,
        ...     pre_group_func_nb=pre_group_func_nb,
        ...     post_group_func_nb=post_group_func_nb,
        ...     pre_segment_func_nb=pre_segment_func_nb,
        ...     pre_segment_args=(size, price, size_type, direction),
        ...     post_segment_func_nb=post_segment_func_nb,
        ...     flex_order_func_nb=flex_order_func_nb,
        ...     flex_order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
        ...     post_order_func_nb=post_order_func_nb
        ... )
        before simulation
            before group 0
                before segment 0
                    creating order 0 at column 0
                        order status: 0
                    creating order 1 at column 1
                        order status: 0
                    creating order 2 at column 2
                        order status: 0
                    breaking out of the loop
                after segment 0
                before segment 2
                    creating order 0 at column 1
                        order status: 0
                    creating order 1 at column 2
                        order status: 0
                    creating order 2 at column 0
                        order status: 0
                    breaking out of the loop
                after segment 2
                before segment 4
                    creating order 0 at column 0
                        order status: 0
                    creating order 1 at column 2
                        order status: 0
                    creating order 2 at column 1
                        order status: 0
                    breaking out of the loop
                after segment 4
            after group 0
        after simulation
        ```
    """
    check_group_lens_nb(group_lens, target_shape[1])

    # Convert flexible (scalar or array-like) inputs into arrays; the underscore-suffixed
    # versions are the ones read with `flex_select_nb` and passed to the contexts below
    init_cash_ = to_1d_array_nb(np.asarray(init_cash))
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    init_price_ = to_1d_array_nb(np.asarray(init_price))
    cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
    cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
    segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
    open_ = to_2d_array_nb(np.asarray(open))
    high_ = to_2d_array_nb(np.asarray(high))
    low_ = to_2d_array_nb(np.asarray(low))
    close_ = to_2d_array_nb(np.asarray(close))
    bm_close_ = to_2d_array_nb(np.asarray(bm_close))

    order_records, log_records = prepare_records_nb(
        target_shape=target_shape,
        max_order_records=max_order_records,
        max_log_records=max_log_records,
    )
    # Running state. Arrays derived from `last_cash` and `last_value` are indexed by group when
    # `cash_sharing` is True and by column otherwise (see how they are accessed below);
    # arrays derived from `last_position` are always indexed by column
    last_cash = prepare_last_cash_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
    )
    last_position = prepare_last_position_nb(
        target_shape=target_shape,
        init_position=init_position_,
    )
    last_value = prepare_last_value_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
    )
    last_pos_info = prepare_last_pos_info_nb(
        target_shape,
        init_position=init_position_,
        init_price=init_price_,
        fill_pos_info=fill_pos_info,
    )
    last_cash_deposits = np.full_like(last_cash, 0.0)
    last_val_price = np.full_like(last_position, np.nan)
    last_debt = np.full_like(last_position, 0.0)
    last_locked_cash = np.full_like(last_position, 0.0)
    last_free_cash = last_cash.copy()
    # Value at the close of the previous row; the baseline for all returns computed within a row
    prev_close_value = last_value.copy()
    last_return = np.full_like(last_cash, np.nan)
    order_counts = np.full(target_shape[1], 0, dtype=int_)
    log_counts = np.full(target_shape[1], 0, dtype=int_)

    # Columns of group `g` span `group_start_idxs[g]` (inclusive) to `group_end_idxs[g]` (exclusive)
    group_end_idxs = np.cumsum(group_lens)
    group_start_idxs = group_end_idxs - group_lens

    # One simulation range per group
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=(target_shape[0], len(group_lens)),
        sim_start=sim_start,
        sim_end=sim_end,
    )

    # Call function before the simulation
    pre_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=None,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)

    # Groups are processed one after another (column-major order): all rows of a group
    # are simulated before moving on to the next group
    for group in prange(len(group_lens)):
        from_col = group_start_idxs[group]
        to_col = group_end_idxs[group]
        group_len = to_col - from_col

        # Call function before the group
        pre_group_ctx = GroupContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=None,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            group=group,
            group_len=group_len,
            from_col=from_col,
            to_col=to_col,
        )
        pre_group_out = pre_group_func_nb(pre_group_ctx, *pre_sim_out, *pre_group_args)

        _sim_start = sim_start_[group]
        _sim_end = sim_end_[group]
        for i in range(_sim_start, _sim_end):
            if track_value:
                # Update valuation price using current open
                for col in range(from_col, to_col):
                    _open = flex_select_nb(open_, i, col)
                    if not np.isnan(_open) or not ffill_val_price:
                        last_val_price[col] = _open

                # Update previous value, current value, and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            is_segment_active = flex_select_nb(segment_mask_, i, group)
            if call_pre_segment or is_segment_active:
                # Call function before the segment
                pre_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=None,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=None,
                )
                pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_group_out, *pre_segment_args)

            # Add cash
            # The deposit is remembered in `last_cash_deposits` because it is subtracted from
            # the value whenever a return is computed later in this row
            if cash_sharing:
                _cash_deposits = flex_select_nb(cash_deposits_, i, group)
                last_cash[group] += _cash_deposits
                last_free_cash[group] += _cash_deposits
                last_cash_deposits[group] = _cash_deposits
            else:
                for col in range(from_col, to_col):
                    _cash_deposits = flex_select_nb(cash_deposits_, i, col)
                    last_cash[col] += _cash_deposits
                    last_free_cash[col] += _cash_deposits
                    last_cash_deposits[col] = _cash_deposits

            if track_value:
                # Update value and return
                # (recomputed from the `last_*` arrays, which were exposed to the pre-segment
                # callback and have just received the cash deposits)
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            # (the mask is read again after the pre-segment callback, which received `segment_mask_`)
            is_segment_active = flex_select_nb(segment_mask_, i, group)
            if is_segment_active:
                # Keep asking the order function for the next (column, order) pair;
                # `call_idx` counts the calls made within this segment, starting at 0
                call_idx = -1
                while True:
                    call_idx += 1

                    # Generate the next order
                    flex_order_ctx = FlexOrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=None,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=None,
                        call_idx=call_idx,
                    )
                    col, order = flex_order_func_nb(flex_order_ctx, *pre_segment_out, *flex_order_args)

                    # Column -1 is the signal to stop issuing orders in this segment
                    if col == -1:
                        break
                    if col < from_col or col >= to_col:
                        raise ValueError("Column out of bounds of the group")
                    if not track_value:
                        if (
                            order.size_type == SizeType.Value
                            or order.size_type == SizeType.TargetValue
                            or order.size_type == SizeType.TargetPercent
                        ):
                            raise ValueError("Cannot use size type that depends on not tracked value")

                    # Get current values
                    position_now = last_position[col]
                    debt_now = last_debt[col]
                    locked_cash_now = last_locked_cash[col]
                    val_price_now = last_val_price[col]
                    pos_info_now = last_pos_info[col]
                    if cash_sharing:
                        cash_now = last_cash[group]
                        free_cash_now = last_free_cash[group]
                        value_now = last_value[group]
                        return_now = last_return[group]
                        cash_deposits_now = last_cash_deposits[group]
                    else:
                        cash_now = last_cash[col]
                        free_cash_now = last_free_cash[col]
                        value_now = last_value[col]
                        return_now = last_return[col]
                        cash_deposits_now = last_cash_deposits[col]

                    # Process the order
                    price_area = PriceArea(
                        open=flex_select_nb(open_, i, col),
                        high=flex_select_nb(high_, i, col),
                        low=flex_select_nb(low_, i, col),
                        close=flex_select_nb(close_, i, col),
                    )
                    exec_state = ExecState(
                        cash=cash_now,
                        position=position_now,
                        debt=debt_now,
                        locked_cash=locked_cash_now,
                        free_cash=free_cash_now,
                        val_price=val_price_now,
                        value=value_now,
                    )
                    order_result, new_exec_state = process_order_nb(
                        group=group,
                        col=col,
                        i=i,
                        exec_state=exec_state,
                        order=order,
                        price_area=price_area,
                        update_value=update_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                    )

                    # Update execution state
                    # (`exec_state` itself is kept untouched: it provides the "before" values
                    # for the position record and the post-order context below)
                    cash_now = new_exec_state.cash
                    position_now = new_exec_state.position
                    debt_now = new_exec_state.debt
                    locked_cash_now = new_exec_state.locked_cash
                    free_cash_now = new_exec_state.free_cash
                    if track_value:
                        val_price_now = new_exec_state.val_price
                        value_now = new_exec_state.value
                        if cash_sharing:
                            return_now = returns_nb_.get_return_nb(
                                prev_close_value[group],
                                value_now - cash_deposits_now,
                            )
                        else:
                            return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)

                    # Now becomes last
                    last_position[col] = position_now
                    last_debt[col] = debt_now
                    last_locked_cash[col] = locked_cash_now
                    if not np.isnan(val_price_now) or not ffill_val_price:
                        last_val_price[col] = val_price_now
                    if cash_sharing:
                        last_cash[group] = cash_now
                        last_free_cash[group] = free_cash_now
                        last_value[group] = value_now
                        last_return[group] = return_now
                    else:
                        last_cash[col] = cash_now
                        last_free_cash[col] = free_cash_now
                        last_value[col] = value_now
                        last_return[col] = return_now
                    # NOTE(review): the block below repeats the valuation price, value, and return
                    # writes performed just above with the same values
                    if track_value:
                        if not np.isnan(val_price_now) or not ffill_val_price:
                            last_val_price[col] = val_price_now
                        if cash_sharing:
                            last_value[group] = value_now
                            last_return[group] = return_now
                        else:
                            last_value[col] = value_now
                            last_return[col] = return_now

                    # Update position record
                    if fill_pos_info:
                        if order_result.status == OrderStatus.Filled:
                            # Id of the most recent order record of this column, or -1 if there is none
                            if order_counts[col] > 0:
                                order_id = order_records["id"][order_counts[col] - 1, col]
                            else:
                                order_id = -1
                            update_pos_info_nb(
                                pos_info_now,
                                i,
                                col,
                                exec_state.position,
                                position_now,
                                order_result,
                                order_id,
                            )

                    # Post-order callback
                    post_order_ctx = PostOrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=None,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=None,
                        col=col,
                        call_idx=call_idx,
                        cash_before=exec_state.cash,
                        position_before=exec_state.position,
                        debt_before=exec_state.debt,
                        locked_cash_before=exec_state.locked_cash,
                        free_cash_before=exec_state.free_cash,
                        val_price_before=exec_state.val_price,
                        value_before=exec_state.value,
                        order_result=order_result,
                        cash_now=cash_now,
                        position_now=position_now,
                        debt_now=debt_now,
                        locked_cash_now=locked_cash_now,
                        free_cash_now=free_cash_now,
                        val_price_now=val_price_now,
                        value_now=value_now,
                        return_now=return_now,
                        pos_info_now=pos_info_now,
                    )
                    post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)

            # NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
            # Add earnings in cash
            for col in range(from_col, to_col):
                _cash_earnings = flex_select_nb(cash_earnings_, i, col)
                if cash_sharing:
                    last_cash[group] += _cash_earnings
                    last_free_cash[group] += _cash_earnings
                else:
                    last_cash[col] += _cash_earnings
                    last_free_cash[col] += _cash_earnings

            if track_value:
                # Update valuation price using current close
                for col in range(from_col, to_col):
                    _close = flex_select_nb(close_, i, col)
                    if not np.isnan(_close) or not ffill_val_price:
                        last_val_price[col] = _close

                # Update previous value, current value, and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                    prev_close_value[group] = last_value[group]
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )
                        prev_close_value[col] = last_value[col]

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            if call_post_segment or is_segment_active:
                # Call function after the segment
                post_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=None,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=None,
                )
                post_segment_func_nb(post_seg_ctx, *pre_group_out, *post_segment_args)

            # `sim_end_` is re-read on every row rather than relying on `_sim_end` captured
            # before the loop; presumably this lets callbacks that received `sim_end_` through
            # the contexts stop the group early by lowering its entry - TODO confirm
            if i >= sim_end_[group] - 1:
                break

        # Call function after the group
        post_group_ctx = GroupContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=None,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            group=group,
            group_len=group_len,
            from_col=from_col,
            to_col=to_col,
        )
        post_group_func_nb(post_group_ctx, *pre_sim_out, *post_group_args)

    # Call function after the simulation
    post_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=None,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    post_sim_func_nb(post_sim_ctx, *post_sim_args)

    # The simulation range was tracked per group; resolve it for the output
    sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        sim_start=sim_start_,
        sim_end=sim_end_,
        allow_none=True,
    )
    return prepare_sim_out_nb(
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        call_seq=None,
        in_outputs=in_outputs,
        sim_start=sim_start_out,
        sim_end=sim_end_out,
    )


# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_row_func_nb_block]
# % blocks["pre_row_func_nb"]
# %? blocks[post_row_func_nb_block]
# % blocks["post_row_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[flex_order_func_nb_block]
# % blocks["flex_order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        target_shape=base_ch.shape_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
        cash_sharing=None,
        init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
        init_position=base_ch.flex_1d_array_gl_slicer,
        init_price=base_ch.flex_1d_array_gl_slicer,
        cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
        cash_earnings=base_ch.flex_array_gl_slicer,
        segment_mask=base_ch.FlexArraySlicer(axis=1),
        call_pre_segment=None,
        call_post_segment=None,
        pre_sim_func_nb=None,  # % None
        pre_sim_args=ch.ArgsTaker(),
        post_sim_func_nb=None,  # % None
        post_sim_args=ch.ArgsTaker(),
        pre_row_func_nb=None,  # % None
        pre_row_args=ch.ArgsTaker(),
        post_row_func_nb=None,  # % None
        post_row_args=ch.ArgsTaker(),
        pre_segment_func_nb=None,  # % None
        pre_segment_args=ch.ArgsTaker(),
        post_segment_func_nb=None,  # % None
        post_segment_args=ch.ArgsTaker(),
        flex_order_func_nb=None,  # % None
        flex_order_args=ch.ArgsTaker(),
        post_order_func_nb=None,  # % None
        post_order_args=ch.ArgsTaker(),
        index=None,
        freq=None,
        open=base_ch.flex_array_gl_slicer,
        high=base_ch.flex_array_gl_slicer,
        low=base_ch.flex_array_gl_slicer,
        close=base_ch.flex_array_gl_slicer,
        bm_close=base_ch.flex_array_gl_slicer,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
        ffill_val_price=None,
        update_value=None,
        fill_pos_info=None,
        track_value=None,
        max_order_records=None,
        max_log_records=None,
        in_outputs=ch.ArgsTaker(),
    ),
    **portfolio_ch.merge_sim_outs_config,
    setup_id=None,  # %? line.replace("None", task_id)
)
@register_jitted(
    tags={"can_parallel"},
    cache=False,  # % line.replace("False", "True")
    task_id_or_func=None,  # %? line.replace("None", task_id)
)
def from_flex_order_func_rw_nb(  # %? line.replace("from_flex_order_func_rw_nb", new_func_name)
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool,
    init_cash: tp.FlexArray1dLike = 100.0,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
    cash_deposits: tp.FlexArray2dLike = 0.0,
    cash_earnings: tp.FlexArray2dLike = 0.0,
    segment_mask: tp.FlexArray2dLike = True,
    call_pre_segment: bool = False,
    call_post_segment: bool = False,
    pre_sim_func_nb: PreSimFuncT = no_pre_func_nb,  # % None
    pre_sim_args: tp.Args = (),
    post_sim_func_nb: PostSimFuncT = no_post_func_nb,  # % None
    post_sim_args: tp.Args = (),
    pre_row_func_nb: PreRowFuncT = no_pre_func_nb,  # % None
    pre_row_args: tp.Args = (),
    post_row_func_nb: PostRowFuncT = no_post_func_nb,  # % None
    post_row_args: tp.Args = (),
    pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb,  # % None
    pre_segment_args: tp.Args = (),
    post_segment_func_nb: PostSegmentFuncT = no_post_func_nb,  # % None
    post_segment_args: tp.Args = (),
    flex_order_func_nb: FlexOrderFuncT = no_flex_order_func_nb,  # % None
    flex_order_args: tp.Args = (),
    post_order_func_nb: PostOrderFuncT = no_post_func_nb,  # % None
    post_order_args: tp.Args = (),
    index: tp.Optional[tp.Array1d] = None,
    freq: tp.Optional[int] = None,
    open: tp.FlexArray2dLike = np.nan,
    high: tp.FlexArray2dLike = np.nan,
    low: tp.FlexArray2dLike = np.nan,
    close: tp.FlexArray2dLike = np.nan,
    bm_close: tp.FlexArray2dLike = np.nan,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
    ffill_val_price: bool = True,
    update_value: bool = False,
    fill_pos_info: bool = True,
    track_value: bool = True,
    max_order_records: tp.Optional[int] = None,
    max_log_records: tp.Optional[int] = 0,
    in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
    """Same as `from_flex_order_func_nb`, but iterates using row-major order: the outer loop
    runs over rows (changing slowest) and the inner loop runs over groups (changing fastest).

    For each row, every group whose simulation range contains that row is processed as
    a segment. Within an active segment, `flex_order_func_nb` is called repeatedly with
    an increasing `call_idx` until it returns the column -1.

    Call hierarchy:

    ```plaintext
    1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
        2. pre_row_out = pre_row_func_nb(RowContext, *pre_sim_out, *pre_row_args)
            3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_row_out, *pre_segment_args)
                while col != -1:
                    4. if segment_mask: col, order = flex_order_func_nb(FlexOrderContext, *pre_segment_out, *flex_order_args)
                    5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
                    ...
            6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_row_out, *post_segment_args)
            ...
        7. post_row_func_nb(RowContext, *pre_sim_out, *post_row_args)
        ...
    8. post_sim_func_nb(SimulationContext, *post_sim_args)
    ```

    Let's illustrate the same example as in `from_order_func_nb` but adapted for this function:

    ```pycon
    >>> @njit
    ... def pre_row_func_nb(c, order_value_out, call_seq_out):
    ...     print('\\tbefore row', c.i)
    ...     return (order_value_out, call_seq_out)

    >>> @njit
    ... def post_row_func_nb(c, order_value_out, call_seq_out):
    ...     print('\\tafter row', c.i)
    ...     return None

    >>> sim_out = vbt.pf_nb.from_flex_order_func_rw_nb(
    ...     target_shape,
    ...     group_lens,
    ...     cash_sharing,
    ...     segment_mask=segment_mask,
    ...     pre_sim_func_nb=pre_sim_func_nb,
    ...     post_sim_func_nb=post_sim_func_nb,
    ...     pre_row_func_nb=pre_row_func_nb,
    ...     post_row_func_nb=post_row_func_nb,
    ...     pre_segment_func_nb=pre_segment_func_nb,
    ...     pre_segment_args=(size, price, size_type, direction),
    ...     post_segment_func_nb=post_segment_func_nb,
    ...     flex_order_func_nb=flex_order_func_nb,
    ...     flex_order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
    ...     post_order_func_nb=post_order_func_nb
    ... )
    ```

    ![](/assets/images/api/from_flex_order_func_rw_nb.svg){: loading=lazy style="width:800px;" }
    """
    check_group_lens_nb(group_lens, target_shape[1])

    # Convert the array-like inputs into 1-dim/2-dim arrays that are then read via flex_select_nb
    init_cash_ = to_1d_array_nb(np.asarray(init_cash))
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    init_price_ = to_1d_array_nb(np.asarray(init_price))
    cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
    cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
    segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
    open_ = to_2d_array_nb(np.asarray(open))
    high_ = to_2d_array_nb(np.asarray(high))
    low_ = to_2d_array_nb(np.asarray(low))
    close_ = to_2d_array_nb(np.asarray(close))
    bm_close_ = to_2d_array_nb(np.asarray(bm_close))

    # Allocate records and the "last" state arrays that are shared with every context
    order_records, log_records = prepare_records_nb(
        target_shape=target_shape,
        max_order_records=max_order_records,
        max_log_records=max_log_records,
    )
    last_cash = prepare_last_cash_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
    )
    last_position = prepare_last_position_nb(
        target_shape=target_shape,
        init_position=init_position_,
    )
    last_value = prepare_last_value_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
    )
    last_pos_info = prepare_last_pos_info_nb(
        target_shape,
        init_position=init_position_,
        init_price=init_price_,
        fill_pos_info=fill_pos_info,
    )
    # Arrays shaped like last_cash are indexed by group under cash sharing and by column otherwise;
    # arrays shaped like last_position are always indexed by column
    last_cash_deposits = np.full_like(last_cash, 0.0)
    last_val_price = np.full_like(last_position, np.nan)
    last_debt = np.full_like(last_position, 0.0)
    last_locked_cash = np.full_like(last_position, 0.0)
    last_free_cash = last_cash.copy()
    prev_close_value = last_value.copy()
    last_return = np.full_like(last_cash, np.nan)
    order_counts = np.full(target_shape[1], 0, dtype=int_)
    log_counts = np.full(target_shape[1], 0, dtype=int_)

    # Column bounds of each group: [group_start_idxs[g], group_end_idxs[g])
    group_end_idxs = np.cumsum(group_lens)
    group_start_idxs = group_end_idxs - group_lens

    # Simulation range is resolved per group (one entry per element of group_lens)
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=(target_shape[0], len(group_lens)),
        sim_start=sim_start,
        sim_end=sim_end,
    )

    # Call function before the simulation
    pre_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=None,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)

    # Iterate over the union of all group ranges; groups outside their own range are skipped below
    _sim_start = sim_start_.min()
    _sim_end = sim_end_.max()
    for i in range(_sim_start, _sim_end):
        # Call function before the row
        pre_row_ctx = RowContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=None,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            i=i,
        )
        pre_row_out = pre_row_func_nb(pre_row_ctx, *pre_sim_out, *pre_row_args)

        for group in range(len(group_lens)):
            # Skip rows that lie outside of this group's simulation range
            if i < sim_start_[group] or i >= sim_end_[group]:
                continue
            from_col = group_start_idxs[group]
            to_col = group_end_idxs[group]
            group_len = to_col - from_col

            if track_value:
                # Update valuation price using current open
                for col in range(from_col, to_col):
                    _open = flex_select_nb(open_, i, col)
                    if not np.isnan(_open) or not ffill_val_price:
                        last_val_price[col] = _open

                # Update previous value, current value, and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
                else:
                    for col in range(from_col, to_col):
                        # The zero-position branch keeps the value equal to cash without touching
                        # last_val_price (which is NaN until a price has been seen)
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            is_segment_active = flex_select_nb(segment_mask_, i, group)
            if call_pre_segment or is_segment_active:
                # Call function before the segment
                pre_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=None,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=None,
                )
                pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_row_out, *pre_segment_args)

            # Add cash
            if cash_sharing:
                _cash_deposits = flex_select_nb(cash_deposits_, i, group)
                last_cash[group] += _cash_deposits
                last_free_cash[group] += _cash_deposits
                last_cash_deposits[group] = _cash_deposits
            else:
                for col in range(from_col, to_col):
                    _cash_deposits = flex_select_nb(cash_deposits_, i, col)
                    last_cash[col] += _cash_deposits
                    last_free_cash[col] += _cash_deposits
                    last_cash_deposits[col] = _cash_deposits

            if track_value:
                # Update value and return
                # (deposits made in this row are subtracted from the value before computing the return)
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            # NOTE(review): the mask is read a second time here; presumably so that changes made to
            # the mask array by pre_segment_func_nb are picked up - confirm
            is_segment_active = flex_select_nb(segment_mask_, i, group)
            if is_segment_active:
                call_idx = -1
                while True:
                    call_idx += 1

                    # Generate the next order
                    flex_order_ctx = FlexOrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=None,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=None,
                        call_idx=call_idx,
                    )
                    col, order = flex_order_func_nb(flex_order_ctx, *pre_segment_out, *flex_order_args)

                    # A column of -1 signals that there are no more orders in this segment
                    if col == -1:
                        break
                    if col < from_col or col >= to_col:
                        raise ValueError("Column out of bounds of the group")
                    if not track_value:
                        if (
                            order.size_type == SizeType.Value
                            or order.size_type == SizeType.TargetValue
                            or order.size_type == SizeType.TargetPercent
                        ):
                            raise ValueError("Cannot use size type that depends on not tracked value")

                    # Get current values
                    position_now = last_position[col]
                    debt_now = last_debt[col]
                    locked_cash_now = last_locked_cash[col]
                    val_price_now = last_val_price[col]
                    pos_info_now = last_pos_info[col]
                    if cash_sharing:
                        cash_now = last_cash[group]
                        free_cash_now = last_free_cash[group]
                        value_now = last_value[group]
                        return_now = last_return[group]
                        cash_deposits_now = last_cash_deposits[group]
                    else:
                        cash_now = last_cash[col]
                        free_cash_now = last_free_cash[col]
                        value_now = last_value[col]
                        return_now = last_return[col]
                        cash_deposits_now = last_cash_deposits[col]

                    # Process the order
                    price_area = PriceArea(
                        open=flex_select_nb(open_, i, col),
                        high=flex_select_nb(high_, i, col),
                        low=flex_select_nb(low_, i, col),
                        close=flex_select_nb(close_, i, col),
                    )
                    exec_state = ExecState(
                        cash=cash_now,
                        position=position_now,
                        debt=debt_now,
                        locked_cash=locked_cash_now,
                        free_cash=free_cash_now,
                        val_price=val_price_now,
                        value=value_now,
                    )
                    order_result, new_exec_state = process_order_nb(
                        group=group,
                        col=col,
                        i=i,
                        exec_state=exec_state,
                        order=order,
                        price_area=price_area,
                        update_value=update_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                    )

                    # Update execution state
                    cash_now = new_exec_state.cash
                    position_now = new_exec_state.position
                    debt_now = new_exec_state.debt
                    locked_cash_now = new_exec_state.locked_cash
                    free_cash_now = new_exec_state.free_cash
                    if track_value:
                        val_price_now = new_exec_state.val_price
                        value_now = new_exec_state.value
                        if cash_sharing:
                            return_now = returns_nb_.get_return_nb(
                                prev_close_value[group],
                                value_now - cash_deposits_now,
                            )
                        else:
                            return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)

                    # Now becomes last
                    last_position[col] = position_now
                    last_debt[col] = debt_now
                    last_locked_cash[col] = locked_cash_now
                    if not np.isnan(val_price_now) or not ffill_val_price:
                        last_val_price[col] = val_price_now
                    if cash_sharing:
                        last_cash[group] = cash_now
                        last_free_cash[group] = free_cash_now
                        last_value[group] = value_now
                        last_return[group] = return_now
                    else:
                        last_cash[col] = cash_now
                        last_free_cash[col] = free_cash_now
                        last_value[col] = value_now
                        last_return[col] = return_now
                    # NOTE(review): the block below rewrites values that were already stored by the
                    # unconditional assignments just above, so it looks redundant - confirm before removing
                    if track_value:
                        if not np.isnan(val_price_now) or not ffill_val_price:
                            last_val_price[col] = val_price_now
                        if cash_sharing:
                            last_value[group] = value_now
                            last_return[group] = return_now
                        else:
                            last_value[col] = value_now
                            last_return[col] = return_now

                    # Update position record
                    if fill_pos_info:
                        if order_result.status == OrderStatus.Filled:
                            # Use the id of the most recent order record of this column, or -1 if there is none
                            if order_counts[col] > 0:
                                order_id = order_records["id"][order_counts[col] - 1, col]
                            else:
                                order_id = -1
                            update_pos_info_nb(
                                pos_info_now,
                                i,
                                col,
                                exec_state.position,
                                position_now,
                                order_result,
                                order_id,
                            )

                    # Post-order callback
                    post_order_ctx = PostOrderContext(
                        target_shape=target_shape,
                        group_lens=group_lens,
                        cash_sharing=cash_sharing,
                        call_seq=None,
                        init_cash=init_cash_,
                        init_position=init_position_,
                        init_price=init_price_,
                        cash_deposits=cash_deposits_,
                        cash_earnings=cash_earnings_,
                        segment_mask=segment_mask_,
                        call_pre_segment=call_pre_segment,
                        call_post_segment=call_post_segment,
                        index=index,
                        freq=freq,
                        open=open_,
                        high=high_,
                        low=low_,
                        close=close_,
                        bm_close=bm_close_,
                        ffill_val_price=ffill_val_price,
                        update_value=update_value,
                        fill_pos_info=fill_pos_info,
                        track_value=track_value,
                        order_records=order_records,
                        order_counts=order_counts,
                        log_records=log_records,
                        log_counts=log_counts,
                        in_outputs=in_outputs,
                        last_cash=last_cash,
                        last_position=last_position,
                        last_debt=last_debt,
                        last_locked_cash=last_locked_cash,
                        last_free_cash=last_free_cash,
                        last_val_price=last_val_price,
                        last_value=last_value,
                        last_return=last_return,
                        last_pos_info=last_pos_info,
                        sim_start=sim_start_,
                        sim_end=sim_end_,
                        group=group,
                        group_len=group_len,
                        from_col=from_col,
                        to_col=to_col,
                        i=i,
                        call_seq_now=None,
                        col=col,
                        call_idx=call_idx,
                        cash_before=exec_state.cash,
                        position_before=exec_state.position,
                        debt_before=exec_state.debt,
                        locked_cash_before=exec_state.locked_cash,
                        free_cash_before=exec_state.free_cash,
                        val_price_before=exec_state.val_price,
                        value_before=exec_state.value,
                        order_result=order_result,
                        cash_now=cash_now,
                        position_now=position_now,
                        debt_now=debt_now,
                        locked_cash_now=locked_cash_now,
                        free_cash_now=free_cash_now,
                        val_price_now=val_price_now,
                        value_now=value_now,
                        return_now=return_now,
                        pos_info_now=pos_info_now,
                    )
                    post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)

            # NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
            # Add earnings in cash
            for col in range(from_col, to_col):
                _cash_earnings = flex_select_nb(cash_earnings_, i, col)
                if cash_sharing:
                    last_cash[group] += _cash_earnings
                    last_free_cash[group] += _cash_earnings
                else:
                    last_cash[col] += _cash_earnings
                    last_free_cash[col] += _cash_earnings

            if track_value:
                # Update valuation price using current close
                for col in range(from_col, to_col):
                    _close = flex_select_nb(close_, i, col)
                    if not np.isnan(_close) or not ffill_val_price:
                        last_val_price[col] = _close

                # Update previous value, current value, and return
                if cash_sharing:
                    last_value[group] = calc_group_value_nb(
                        from_col,
                        to_col,
                        last_cash[group],
                        last_position,
                        last_val_price,
                    )
                    last_return[group] = returns_nb_.get_return_nb(
                        prev_close_value[group],
                        last_value[group] - last_cash_deposits[group],
                    )
                    # The value at this close becomes the reference for the returns of the next row
                    prev_close_value[group] = last_value[group]
                else:
                    for col in range(from_col, to_col):
                        if last_position[col] == 0:
                            last_value[col] = last_cash[col]
                        else:
                            last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
                        last_return[col] = returns_nb_.get_return_nb(
                            prev_close_value[col],
                            last_value[col] - last_cash_deposits[col],
                        )
                        prev_close_value[col] = last_value[col]

                # Update open position stats
                if fill_pos_info:
                    for col in range(from_col, to_col):
                        update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])

            # Is this segment active?
            if call_post_segment or is_segment_active:
                # Call function after the segment
                post_seg_ctx = SegmentContext(
                    target_shape=target_shape,
                    group_lens=group_lens,
                    cash_sharing=cash_sharing,
                    call_seq=None,
                    init_cash=init_cash_,
                    init_position=init_position_,
                    init_price=init_price_,
                    cash_deposits=cash_deposits_,
                    cash_earnings=cash_earnings_,
                    segment_mask=segment_mask_,
                    call_pre_segment=call_pre_segment,
                    call_post_segment=call_post_segment,
                    index=index,
                    freq=freq,
                    open=open_,
                    high=high_,
                    low=low_,
                    close=close_,
                    bm_close=bm_close_,
                    ffill_val_price=ffill_val_price,
                    update_value=update_value,
                    fill_pos_info=fill_pos_info,
                    track_value=track_value,
                    order_records=order_records,
                    order_counts=order_counts,
                    log_records=log_records,
                    log_counts=log_counts,
                    in_outputs=in_outputs,
                    last_cash=last_cash,
                    last_position=last_position,
                    last_debt=last_debt,
                    last_locked_cash=last_locked_cash,
                    last_free_cash=last_free_cash,
                    last_val_price=last_val_price,
                    last_value=last_value,
                    last_return=last_return,
                    last_pos_info=last_pos_info,
                    sim_start=sim_start_,
                    sim_end=sim_end_,
                    group=group,
                    group_len=group_len,
                    from_col=from_col,
                    to_col=to_col,
                    i=i,
                    call_seq_now=None,
                )
                post_segment_func_nb(post_seg_ctx, *pre_row_out, *post_segment_args)

        # Call function after the row
        post_row_ctx = RowContext(
            target_shape=target_shape,
            group_lens=group_lens,
            cash_sharing=cash_sharing,
            call_seq=None,
            init_cash=init_cash_,
            init_position=init_position_,
            init_price=init_price_,
            cash_deposits=cash_deposits_,
            cash_earnings=cash_earnings_,
            segment_mask=segment_mask_,
            call_pre_segment=call_pre_segment,
            call_post_segment=call_post_segment,
            index=index,
            freq=freq,
            open=open_,
            high=high_,
            low=low_,
            close=close_,
            bm_close=bm_close_,
            ffill_val_price=ffill_val_price,
            update_value=update_value,
            fill_pos_info=fill_pos_info,
            track_value=track_value,
            order_records=order_records,
            order_counts=order_counts,
            log_records=log_records,
            log_counts=log_counts,
            in_outputs=in_outputs,
            last_cash=last_cash,
            last_position=last_position,
            last_debt=last_debt,
            last_locked_cash=last_locked_cash,
            last_free_cash=last_free_cash,
            last_val_price=last_val_price,
            last_value=last_value,
            last_return=last_return,
            last_pos_info=last_pos_info,
            sim_start=sim_start_,
            sim_end=sim_end_,
            i=i,
        )
        post_row_func_nb(post_row_ctx, *pre_sim_out, *post_row_args)

        # Stop early once this row is at or past the last row of every group's range;
        # sim_end_ is re-read on each row rather than relying on _sim_end computed before the loop
        sim_end_reached = True
        for group in range(len(group_lens)):
            if i < sim_end_[group] - 1:
                sim_end_reached = False
                break
        if sim_end_reached:
            break

    # Call function after the simulation
    post_sim_ctx = SimulationContext(
        target_shape=target_shape,
        group_lens=group_lens,
        cash_sharing=cash_sharing,
        call_seq=None,
        init_cash=init_cash_,
        init_position=init_position_,
        init_price=init_price_,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        segment_mask=segment_mask_,
        call_pre_segment=call_pre_segment,
        call_post_segment=call_post_segment,
        index=index,
        freq=freq,
        open=open_,
        high=high_,
        low=low_,
        close=close_,
        bm_close=bm_close_,
        ffill_val_price=ffill_val_price,
        update_value=update_value,
        fill_pos_info=fill_pos_info,
        track_value=track_value,
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        in_outputs=in_outputs,
        last_cash=last_cash,
        last_position=last_position,
        last_debt=last_debt,
        last_locked_cash=last_locked_cash,
        last_free_cash=last_free_cash,
        last_val_price=last_val_price,
        last_value=last_value,
        last_return=last_return,
        last_pos_info=last_pos_info,
        sim_start=sim_start_,
        sim_end=sim_end_,
    )
    post_sim_func_nb(post_sim_ctx, *post_sim_args)

    # Convert the per-group simulation range into the form returned with the output
    sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
        target_shape=target_shape,
        group_lens=group_lens,
        sim_start=sim_start_,
        sim_end=sim_end_,
        allow_none=True,
    )
    return prepare_sim_out_nb(
        order_records=order_records,
        order_counts=order_counts,
        log_records=log_records,
        log_counts=log_counts,
        cash_deposits=cash_deposits_,
        cash_earnings=cash_earnings_,
        call_seq=None,
        in_outputs=in_outputs,
        sim_start=sim_start_out,
        sim_end=sim_end_out,
    )


# %
@register_jitted
def set_val_price_nb(c: SegmentContext, val_price: tp.FlexArray2d, price: tp.FlexArray2d) -> None:
    """Override valuation price in a context.

    For each column of the current segment, the selected `val_price` is written into
    `c.last_val_price` unless it is NaN and `c.ffill_val_price` is enabled.

    Allows specifying a valuation price of positive infinity (takes the current price)
    and negative infinity (takes the latest valuation price). When the current price is
    itself infinite, positive infinity resolves to the close and negative infinity to the open."""
    for col in range(c.from_col, c.to_col):
        _val_price = select_from_col_nb(c, col, val_price)
        if np.isinf(_val_price):
            if _val_price > 0:
                # Positive infinity: use the order price, resolving an infinite price to close/open
                _price = select_from_col_nb(c, col, price)
                if np.isinf(_price):
                    if _price > 0:
                        _price = select_from_col_nb(c, col, c.close)
                    else:
                        _price = select_from_col_nb(c, col, c.open)
                _val_price = _price
            else:
                # Negative infinity: keep the latest known valuation price
                _val_price = c.last_val_price[col]
        # A NaN only overrides the stored price when forward-filling is disabled
        if not np.isnan(_val_price) or not c.ffill_val_price:
            c.last_val_price[col] = _val_price


# %
@register_jitted
def def_pre_segment_func_nb(  # % line.replace("def_pre_segment_func_nb", "pre_segment_func_nb")
    c: SegmentContext,
    val_price: tp.FlexArray2d,
    price: tp.FlexArray2d,
    size: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    auto_call_seq: bool,
) -> tp.Args:
    """Pre-segment function that overrides the valuation price and optionally sorts the call sequence.

    Calls `set_val_price_nb` first and, if `auto_call_seq` is True, `sort_call_seq_nb`
    with a freshly allocated buffer of length `c.group_len`. Returns an empty tuple."""
    set_val_price_nb(c, val_price, price)
    if auto_call_seq:
        # Scratch buffer that is handed to the sorting routine
        order_value_out = np.empty(c.group_len, dtype=float_)
        sort_call_seq_nb(c, size, size_type, direction, order_value_out)
    return ()


# %
# %
@register_jitted
def def_order_func_nb(  # % line.replace("def_order_func_nb", "order_func_nb")
    c: OrderContext,
    size: tp.FlexArray2d,
    price: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    fees: tp.FlexArray2d,
    fixed_fees: tp.FlexArray2d,
    slippage: tp.FlexArray2d,
    min_size: tp.FlexArray2d,
    max_size: tp.FlexArray2d,
    size_granularity: tp.FlexArray2d,
    leverage: tp.FlexArray2d,
    leverage_mode: tp.FlexArray2d,
    reject_prob: tp.FlexArray2d,
    price_area_vio_mode: tp.FlexArray2d,
    allow_partial: tp.FlexArray2d,
    raise_reject: tp.FlexArray2d,
    log: tp.FlexArray2d,
) -> Order:
    """Order function that creates an order based on default information.

    Every order field is picked from the corresponding array with `select_nb` and
    passed to `order_nb`. Returns the order itself (not a column/order pair, which
    is what the flexible counterpart `def_flex_order_func_nb` returns)."""
    return order_nb(
        size=select_nb(c, size),
        price=select_nb(c, price),
        size_type=select_nb(c, size_type),
        direction=select_nb(c, direction),
        fees=select_nb(c, fees),
        fixed_fees=select_nb(c, fixed_fees),
        slippage=select_nb(c, slippage),
        min_size=select_nb(c, min_size),
        max_size=select_nb(c, max_size),
        size_granularity=select_nb(c, size_granularity),
        leverage=select_nb(c, leverage),
        leverage_mode=select_nb(c, leverage_mode),
        reject_prob=select_nb(c, reject_prob),
        price_area_vio_mode=select_nb(c, price_area_vio_mode),
        allow_partial=select_nb(c, allow_partial),
        raise_reject=select_nb(c, raise_reject),
        log=select_nb(c, log),
    )


# %
# %
@register_jitted
def def_flex_pre_segment_func_nb(  # % line.replace("def_flex_pre_segment_func_nb", "pre_segment_func_nb")
    c: SegmentContext,
    val_price: tp.FlexArray2d,
    price: tp.FlexArray2d,
    size: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    auto_call_seq: bool,
) -> tp.Args:
    """Flexible pre-segment function that overrides the valuation price and optionally sorts the call sequence.

    The call sequence starts as `np.arange(c.group_len)` and is passed to
    `sort_call_seq_out_nb` when `auto_call_seq` is True. Returns a one-element tuple
    holding the call sequence array."""
    set_val_price_nb(c, val_price, price)
    call_seq_out = np.arange(c.group_len)
    if auto_call_seq:
        # Scratch buffer that is handed to the sorting routine
        order_value_out = np.empty(c.group_len, dtype=float_)
        sort_call_seq_out_nb(c, size, size_type, direction, order_value_out, call_seq_out)
    return (call_seq_out,)


# %
# %
@register_jitted
def def_flex_order_func_nb(  # % line.replace("def_flex_order_func_nb", "flex_order_func_nb")
    c: FlexOrderContext,
    call_seq_now: tp.Array1d,
    size: tp.FlexArray2d,
    price: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    fees: tp.FlexArray2d,
    fixed_fees: tp.FlexArray2d,
    slippage: tp.FlexArray2d,
    min_size: tp.FlexArray2d,
    max_size: tp.FlexArray2d,
    size_granularity: tp.FlexArray2d,
    leverage: tp.FlexArray2d,
    leverage_mode: tp.FlexArray2d,
    reject_prob: tp.FlexArray2d,
    price_area_vio_mode: tp.FlexArray2d,
    allow_partial: tp.FlexArray2d,
    raise_reject: tp.FlexArray2d,
    log: tp.FlexArray2d,
) -> tp.Tuple[int, Order]:
    """Flexible order function that creates an order based on default information.

    While `c.call_idx` is below `c.group_len`, returns the column
    `c.from_col + call_seq_now[c.call_idx]` together with an order built from the
    values selected for that column. Afterwards returns -1 and the result of
    `order_nothing_nb()`."""
    if c.call_idx < c.group_len:
        # Translate the position in the call sequence into an absolute column index
        col = c.from_col + call_seq_now[c.call_idx]
        order = order_nb(
            size=select_from_col_nb(c, col, size),
            price=select_from_col_nb(c, col, price),
            size_type=select_from_col_nb(c, col, size_type),
            direction=select_from_col_nb(c, col, direction),
            fees=select_from_col_nb(c, col, fees),
            fixed_fees=select_from_col_nb(c, col, fixed_fees),
            slippage=select_from_col_nb(c, col, slippage),
            min_size=select_from_col_nb(c, col, min_size),
            max_size=select_from_col_nb(c, col, max_size),
            size_granularity=select_from_col_nb(c, col, size_granularity),
            leverage=select_from_col_nb(c, col, leverage),
            leverage_mode=select_from_col_nb(c, col, leverage_mode),
            reject_prob=select_from_col_nb(c, col, reject_prob),
            price_area_vio_mode=select_from_col_nb(c, col, price_area_vio_mode),
            allow_partial=select_from_col_nb(c, col, allow_partial),
            raise_reject=select_from_col_nb(c, col, raise_reject),
            log=select_from_col_nb(c, col, log),
        )
        return col, order
    return -1, order_nothing_nb()


# %
# ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Numba-compiled functions for portfolio simulation based on orders.""" from numba import prange from vectorbtpro.base import chunking as base_ch from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb from vectorbtpro.portfolio import chunking as portfolio_ch from vectorbtpro.portfolio.nb.core import * from vectorbtpro.registries.ch_registry import register_chunkable from vectorbtpro.returns.nb import get_return_nb from vectorbtpro.utils import chunking as ch from vectorbtpro.utils.array_ import insert_argsort_nb @register_chunkable( size=ch.ArraySizer(arg_query="group_lens", axis=0), arg_take_spec=dict( target_shape=base_ch.shape_gl_slicer, group_lens=ch.ArraySlicer(axis=0), open=base_ch.flex_array_gl_slicer, high=base_ch.flex_array_gl_slicer, low=base_ch.flex_array_gl_slicer, close=base_ch.flex_array_gl_slicer, init_cash=base_ch.FlexArraySlicer(), init_position=base_ch.flex_1d_array_gl_slicer, init_price=base_ch.flex_1d_array_gl_slicer, cash_deposits=base_ch.FlexArraySlicer(axis=1), cash_earnings=base_ch.flex_array_gl_slicer, cash_dividends=base_ch.flex_array_gl_slicer, size=base_ch.flex_array_gl_slicer, price=base_ch.flex_array_gl_slicer, size_type=base_ch.flex_array_gl_slicer, direction=base_ch.flex_array_gl_slicer, fees=base_ch.flex_array_gl_slicer, fixed_fees=base_ch.flex_array_gl_slicer, slippage=base_ch.flex_array_gl_slicer, min_size=base_ch.flex_array_gl_slicer, max_size=base_ch.flex_array_gl_slicer, 
size_granularity=base_ch.flex_array_gl_slicer, leverage=base_ch.flex_array_gl_slicer, leverage_mode=base_ch.flex_array_gl_slicer, reject_prob=base_ch.flex_array_gl_slicer, price_area_vio_mode=base_ch.flex_array_gl_slicer, allow_partial=base_ch.flex_array_gl_slicer, raise_reject=base_ch.flex_array_gl_slicer, log=base_ch.flex_array_gl_slicer, val_price=base_ch.flex_array_gl_slicer, from_ago=base_ch.flex_array_gl_slicer, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), call_seq=base_ch.array_gl_slicer, auto_call_seq=None, ffill_val_price=None, update_value=None, save_state=None, save_value=None, save_returns=None, skip_empty=None, max_order_records=None, max_log_records=None, ), **portfolio_ch.merge_sim_outs_config, ) @register_jitted(cache=True, tags={"can_parallel"}) def from_orders_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, open: tp.FlexArray2dLike = np.nan, high: tp.FlexArray2dLike = np.nan, low: tp.FlexArray2dLike = np.nan, close: tp.FlexArray2dLike = np.nan, init_cash: tp.FlexArray1dLike = 100.0, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, cash_deposits: tp.FlexArray2dLike = 0.0, cash_earnings: tp.FlexArray2dLike = 0.0, cash_dividends: tp.FlexArray2dLike = 0.0, size: tp.FlexArray2dLike = np.inf, price: tp.FlexArray2dLike = np.inf, size_type: tp.FlexArray2dLike = SizeType.Amount, direction: tp.FlexArray2dLike = Direction.Both, fees: tp.FlexArray2dLike = 0.0, fixed_fees: tp.FlexArray2dLike = 0.0, slippage: tp.FlexArray2dLike = 0.0, min_size: tp.FlexArray2dLike = np.nan, max_size: tp.FlexArray2dLike = np.nan, size_granularity: tp.FlexArray2dLike = np.nan, leverage: tp.FlexArray2dLike = 1.0, leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy, reject_prob: tp.FlexArray2dLike = 0.0, price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore, allow_partial: tp.FlexArray2dLike = True, raise_reject: tp.FlexArray2dLike = False, log: tp.FlexArray2dLike = False, val_price: tp.FlexArray2dLike = 
np.inf, from_ago: tp.FlexArray2dLike = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, call_seq: tp.Optional[tp.Array2d] = None, auto_call_seq: bool = False, ffill_val_price: bool = True, update_value: bool = False, save_state: bool = False, save_value: bool = False, save_returns: bool = False, skip_empty: bool = True, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = 0, ) -> SimulationOutput: """Creates on order out of each element. Iterates in the column-major order. Utilizes flexible broadcasting. !!! note Should be only grouped if cash sharing is enabled. Single value must be passed as a 0-dim array (for example, by using `np.asarray(value)`). Usage: * Buy and hold using all cash and closing price (default): ```pycon >>> from vectorbtpro import * >>> from vectorbtpro.records.nb import col_map_nb >>> from vectorbtpro.portfolio.nb import from_orders_nb, asset_flow_nb >>> close = np.array([1, 2, 3, 4, 5])[:, None] >>> sim_out = from_orders_nb( ... target_shape=close.shape, ... group_lens=np.array([1]), ... call_seq=np.full(close.shape, 0), ... close=close ... 
) >>> col_map = col_map_nb(sim_out.order_records['col'], close.shape[1]) >>> asset_flow = asset_flow_nb(close.shape, sim_out.order_records, col_map) >>> asset_flow array([[100.], [ 0.], [ 0.], [ 0.], [ 0.]]) ``` """ check_group_lens_nb(group_lens, target_shape[1]) cash_sharing = is_grouped_nb(group_lens) open_ = to_2d_array_nb(np.asarray(open)) high_ = to_2d_array_nb(np.asarray(high)) low_ = to_2d_array_nb(np.asarray(low)) close_ = to_2d_array_nb(np.asarray(close)) init_cash_ = to_1d_array_nb(np.asarray(init_cash)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends)) size_ = to_2d_array_nb(np.asarray(size)) price_ = to_2d_array_nb(np.asarray(price)) size_type_ = to_2d_array_nb(np.asarray(size_type)) direction_ = to_2d_array_nb(np.asarray(direction)) fees_ = to_2d_array_nb(np.asarray(fees)) fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees)) slippage_ = to_2d_array_nb(np.asarray(slippage)) min_size_ = to_2d_array_nb(np.asarray(min_size)) max_size_ = to_2d_array_nb(np.asarray(max_size)) size_granularity_ = to_2d_array_nb(np.asarray(size_granularity)) leverage_ = to_2d_array_nb(np.asarray(leverage)) leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode)) reject_prob_ = to_2d_array_nb(np.asarray(reject_prob)) price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode)) allow_partial_ = to_2d_array_nb(np.asarray(allow_partial)) raise_reject_ = to_2d_array_nb(np.asarray(raise_reject)) log_ = to_2d_array_nb(np.asarray(log)) val_price_ = to_2d_array_nb(np.asarray(val_price)) from_ago_ = to_2d_array_nb(np.asarray(from_ago)) order_records, log_records = prepare_records_nb( target_shape=target_shape, max_order_records=max_order_records, max_log_records=max_log_records, ) order_counts = np.full(target_shape[1], 0, dtype=int_) 
log_counts = np.full(target_shape[1], 0, dtype=int_) last_cash = prepare_last_cash_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, ) last_position = prepare_last_position_nb( target_shape=target_shape, init_position=init_position_, ) last_value = prepare_last_value_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, ) last_cash_deposits = np.full_like(last_cash, 0.0) last_val_price = np.full_like(last_position, np.nan) last_debt = np.full(target_shape[1], 0.0, dtype=float_) last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_) last_free_cash = last_cash.copy() prev_close_value = last_value.copy() last_return = np.full_like(last_cash, np.nan) track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1 if track_cash_deposits: cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_) else: cash_deposits_out = np.full((1, 1), 0.0, dtype=float_) track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1 track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1 track_cash_earnings = track_cash_earnings or track_cash_dividends if track_cash_earnings: cash_earnings_out = np.full(target_shape, 0.0, dtype=float_) else: cash_earnings_out = np.full((1, 1), 0.0, dtype=float_) if save_state: cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) position = np.full(target_shape, np.nan, dtype=float_) debt = np.full(target_shape, np.nan, dtype=float_) locked_cash = np.full(target_shape, np.nan, dtype=float_) free_cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: cash = np.full((0, 0), np.nan, dtype=float_) position = np.full((0, 0), np.nan, dtype=float_) debt = np.full((0, 0), np.nan, dtype=float_) 
locked_cash = np.full((0, 0), np.nan, dtype=float_) free_cash = np.full((0, 0), np.nan, dtype=float_) if save_value: value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: value = np.full((0, 0), np.nan, dtype=float_) if save_returns: returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: returns = np.full((0, 0), np.nan, dtype=float_) in_outputs = FOInOutputs( cash=cash, position=position, debt=debt, locked_cash=locked_cash, free_cash=free_cash, value=value, returns=returns, ) temp_call_seq = np.empty(target_shape[1], dtype=int_) temp_order_value = np.empty(target_shape[1], dtype=float_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(target_shape[0], len(group_lens)), sim_start=sim_start, sim_end=sim_end, ) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] group_len = to_col - from_col _sim_start = sim_start_[group] _sim_end = sim_end_[group] for i in range(_sim_start, _sim_end): # Add cash _cash_deposits = flex_select_nb(cash_deposits_, i, group) if _cash_deposits < 0: _cash_deposits = max(_cash_deposits, -last_cash[group]) last_cash[group] += _cash_deposits last_free_cash[group] += _cash_deposits last_cash_deposits[group] = _cash_deposits if track_cash_deposits: cash_deposits_out[i, group] += _cash_deposits skip = skip_empty if skip: for c in range(group_len): col = from_col + c _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: continue if flex_select_nb(log_, i, col): skip = False break if not np.isnan(flex_select_nb(size_, _i, col)): if not np.isnan(flex_select_nb(price_, _i, col)): skip = False break if not skip or ffill_val_price: for c in range(group_len): col = from_col + c # Update valuation price using current open _open = flex_select_nb(open_, i, col) if not np.isnan(_open) or not ffill_val_price: last_val_price[col] = _open # Resolve 
valuation price _val_price = flex_select_nb(val_price_, i, col) if np.isinf(_val_price): if _val_price > 0: _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: _price = np.nan else: _price = flex_select_nb(price_, _i, col) if np.isinf(_price): if _price > 0: _price = flex_select_nb(close_, i, col) else: _price = _open _val_price = _price else: _val_price = last_val_price[col] if not np.isnan(_val_price) or not ffill_val_price: last_val_price[col] = _val_price if not skip: # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) if cash_sharing: # Dynamically sort by order value -> selling comes first to release funds early if call_seq is None: for c in range(group_len): temp_call_seq[c] = c call_seq_now = temp_call_seq[:group_len] else: call_seq_now = call_seq[i, from_col:to_col] if auto_call_seq: # Same as sort_by_order_value_ctx_nb but with flexible indexing for c in range(group_len): col = from_col + c exec_state = ExecState( cash=last_cash[group] if cash_sharing else last_cash[col], position=last_position[col], debt=last_debt[col], locked_cash=last_locked_cash[col], free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col], val_price=last_val_price[col], value=last_value[group] if cash_sharing else last_value[col], ) _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: temp_order_value[c] = 0.0 else: temp_order_value[c] = approx_order_value_nb( exec_state, flex_select_nb(size_, _i, col), flex_select_nb(size_type_, _i, col), flex_select_nb(direction_, _i, col), ) if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") # Sort by order value insert_argsort_nb(temp_order_value[:group_len], call_seq_now) for k in range(group_len): if 
cash_sharing: c = call_seq_now[k] if c >= group_len: raise ValueError("Call index out of bounds of the group") else: c = k col = from_col + c # Get current values per column position_now = last_position[col] debt_now = last_debt[col] locked_cash_now = last_locked_cash[col] val_price_now = last_val_price[col] cash_now = last_cash[group] free_cash_now = last_free_cash[group] value_now = last_value[group] return_now = last_return[group] # Generate the next order _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: continue order = order_nb( size=flex_select_nb(size_, _i, col), price=flex_select_nb(price_, _i, col), size_type=flex_select_nb(size_type_, _i, col), direction=flex_select_nb(direction_, _i, col), fees=flex_select_nb(fees_, _i, col), fixed_fees=flex_select_nb(fixed_fees_, _i, col), slippage=flex_select_nb(slippage_, _i, col), min_size=flex_select_nb(min_size_, _i, col), max_size=flex_select_nb(max_size_, _i, col), size_granularity=flex_select_nb(size_granularity_, _i, col), leverage=flex_select_nb(leverage_, _i, col), leverage_mode=flex_select_nb(leverage_mode_, _i, col), reject_prob=flex_select_nb(reject_prob_, _i, col), price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col), allow_partial=flex_select_nb(allow_partial_, _i, col), raise_reject=flex_select_nb(raise_reject_, _i, col), log=flex_select_nb(log_, _i, col), ) # Process the order price_area = PriceArea( open=flex_select_nb(open_, i, col), high=flex_select_nb(high_, i, col), low=flex_select_nb(low_, i, col), close=flex_select_nb(close_, i, col), ) exec_state = ExecState( cash=cash_now, position=position_now, debt=debt_now, locked_cash=locked_cash_now, free_cash=free_cash_now, val_price=val_price_now, value=value_now, ) order_result, new_exec_state = process_order_nb( group=group, col=col, i=i, exec_state=exec_state, order=order, price_area=price_area, update_value=update_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, ) # 
Update execution state cash_now = new_exec_state.cash position_now = new_exec_state.position debt_now = new_exec_state.debt locked_cash_now = new_exec_state.locked_cash free_cash_now = new_exec_state.free_cash val_price_now = new_exec_state.val_price value_now = new_exec_state.value # Now becomes last last_position[col] = position_now last_debt[col] = debt_now last_locked_cash[col] = locked_cash_now if not np.isnan(val_price_now) or not ffill_val_price: last_val_price[col] = val_price_now last_cash[group] = cash_now last_free_cash[group] = free_cash_now last_value[group] = value_now last_return[group] = return_now for col in range(from_col, to_col): # Update valuation price using current close _close = flex_select_nb(close_, i, col) if not np.isnan(_close) or not ffill_val_price: last_val_price[col] = _close _cash_earnings = flex_select_nb(cash_earnings_, i, col) _cash_dividends = flex_select_nb(cash_dividends_, i, col) _cash_earnings += _cash_dividends * last_position[col] last_cash[group] += _cash_earnings last_free_cash[group] += _cash_earnings if track_cash_earnings: cash_earnings_out[i, col] += _cash_earnings if save_state: position[i, col] = last_position[col] debt[i, col] = last_debt[col] locked_cash[i, col] = last_locked_cash[col] cash[i, group] = last_cash[group] free_cash[i, group] = last_free_cash[group] # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) prev_close_value[group] = last_value[group] if save_value: in_outputs.value[i, group] = last_value[group] if save_returns: in_outputs.returns[i, group] = last_return[group] sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start_, 
@register_jitted(cache=True)
def resolve_pending_conflict_nb(
    is_pending_long: bool,
    is_user_long: bool,
    upon_adj_conflict: int,
    upon_opp_conflict: int,
) -> tp.Tuple[bool, bool]:
    """Decide what to do when a user-defined signal meets a pending signal.

    When both signals point the same way (both long or both short), the mode in
    `upon_adj_conflict` is applied; otherwise the mode in `upon_opp_conflict` is applied.
    The selected mode is compared against the members of `PendingConflictMode`.

    Returns a tuple of two flags: whether to keep the pending signal and whether to
    execute the user signal.

    Raises `ValueError` if the selected mode is none of `KeepIgnore`, `KeepExecute`,
    `CancelIgnore`, or `CancelExecute`."""
    # Comparing the negations keeps the truthiness semantics of the inputs
    same_side = (not is_pending_long) == (not is_user_long)
    if same_side:
        mode = upon_adj_conflict
    else:
        mode = upon_opp_conflict

    # First flag: keep pending signal, second flag: execute user signal
    if mode == PendingConflictMode.KeepIgnore:
        return True, False
    if mode == PendingConflictMode.KeepExecute:
        return True, True
    if mode == PendingConflictMode.CancelIgnore:
        return False, False
    if mode == PendingConflictMode.CancelExecute:
        return False, True
    raise ValueError("Invalid PendingConflictMode option")
@register_jitted(cache=True)
def resolve_signal_conflict_nb(
    position_now: float,
    is_entry: bool,
    is_exit: bool,
    direction: int,
    conflict_mode: int,
) -> tp.Tuple[bool, bool]:
    """Resolve a simultaneous entry and exit signal.

    If the two signals are not both set, they are returned untouched and `conflict_mode`
    is not inspected. Otherwise, `conflict_mode` (compared against `ConflictMode`) together
    with the sign of `position_now` and whether `direction` equals `Direction.Both` decides
    which of the two signals gets switched off.

    Returns the (possibly cleared) entry and exit flags.

    Raises `ValueError` if there is a conflict and `conflict_mode` is not a handled option."""
    if not (is_entry and is_exit):
        # Nothing to resolve
        return is_entry, is_exit

    both_directions = direction == Direction.Both
    drop_entry = False
    drop_exit = False

    if conflict_mode == ConflictMode.Entry:
        # Entry wins
        drop_exit = True
    elif conflict_mode == ConflictMode.Exit:
        # Exit wins
        drop_entry = True
    elif conflict_mode == ConflictMode.Adjacent:
        # Prefer the signal on the same side as the current position
        if position_now == 0:
            # No position to be adjacent to -> drop both
            drop_entry = True
            drop_exit = True
        elif not both_directions or position_now > 0:
            drop_exit = True
        elif position_now < 0:
            drop_entry = True
        # A NaN position with both directions falls through and leaves both signals set
    elif conflict_mode == ConflictMode.Opposite:
        # Prefer the signal on the other side of the current position
        if position_now == 0:
            drop_exit = True
            if both_directions:
                # Cannot pick a side -> drop both
                drop_entry = True
        elif not both_directions or position_now > 0:
            drop_entry = True
        elif position_now < 0:
            drop_exit = True
        # A NaN position with both directions falls through and leaves both signals set
    elif conflict_mode == ConflictMode.Ignore:
        drop_entry = True
        drop_exit = True
    else:
        raise ValueError("Invalid ConflictMode option")

    if drop_entry:
        is_entry = False
    if drop_exit:
        is_exit = False
    return is_entry, is_exit
@register_jitted(cache=True)
def resolve_opposite_entry_nb(
    position_now: float,
    is_long_entry: bool,
    is_long_exit: bool,
    is_short_entry: bool,
    is_short_exit: bool,
    upon_opposite_entry: int,
    accumulate: int,
) -> tp.Tuple[bool, bool, bool, bool, int]:
    """Resolve an entry signal that points against the current position.

    Applies only when a short entry arrives while `position_now` is positive, or a long
    entry arrives while `position_now` is negative; in any other case all arguments are
    returned unchanged and `upon_opposite_entry` is not inspected.

    Depending on `upon_opposite_entry` (compared against `OppositeEntryMode`):

    * `Ignore`: the opposing entry is cleared
    * `Close`: the opposing entry is replaced by an exit of the current side and
        `accumulate` is set to `AccumulationMode.Disabled`
    * `CloseReduce`: the opposing entry is replaced by an exit of the current side
    * `Reverse`: signals are kept and `accumulate` is set to `AccumulationMode.Disabled`
    * `ReverseReduce`: nothing is changed

    Returns the four signal flags followed by the accumulation mode.

    Raises `ValueError` on any other `upon_opposite_entry` value."""
    short_against_long = position_now > 0 and is_short_entry
    long_against_short = position_now < 0 and is_long_entry

    if short_against_long or long_against_short:
        # Translate the mode into three independent actions
        if upon_opposite_entry == OppositeEntryMode.Ignore:
            clear_entry = True
            set_exit = False
            disable_accumulation = False
        elif upon_opposite_entry == OppositeEntryMode.Close:
            clear_entry = True
            set_exit = True
            disable_accumulation = True
        elif upon_opposite_entry == OppositeEntryMode.CloseReduce:
            clear_entry = True
            set_exit = True
            disable_accumulation = False
        elif upon_opposite_entry == OppositeEntryMode.Reverse:
            clear_entry = False
            set_exit = False
            disable_accumulation = True
        elif upon_opposite_entry == OppositeEntryMode.ReverseReduce:
            clear_entry = False
            set_exit = False
            disable_accumulation = False
        else:
            raise ValueError("Invalid OppositeEntryMode option")

        # Apply the actions to whichever side is affected; the two cases are mutually exclusive
        if clear_entry:
            if short_against_long:
                is_short_entry = False
            else:
                is_long_entry = False
        if set_exit:
            if short_against_long:
                is_long_exit = True
            else:
                is_short_exit = True
        if disable_accumulation:
            accumulate = AccumulationMode.Disabled

    return is_long_entry, is_long_exit, is_short_entry, is_short_exit, accumulate
position_now: float, val_price_now: float, value_now: float, is_long_entry: bool, is_long_exit: bool, is_short_entry: bool, is_short_exit: bool, size: float, size_type: int, accumulate: int, ) -> tp.Tuple[float, int, int]: """Translate direction-aware signals into size, size type, and direction.""" if ( accumulate != AccumulationMode.Disabled and accumulate != AccumulationMode.Both and accumulate != AccumulationMode.AddOnly and accumulate != AccumulationMode.RemoveOnly ): raise ValueError("Invalid AccumulationMode option") def _check_size_type(_size_type): if ( _size_type == SizeType.TargetAmount or _size_type == SizeType.TargetValue or _size_type == SizeType.TargetPercent or _size_type == SizeType.TargetPercent100 ): raise ValueError("Target size types are not supported") if is_less_nb(size, 0): raise ValueError("Negative size is not allowed. Please express direction using signals.") if size_type == SizeType.Percent100: size /= 100 size_type = SizeType.Percent if size_type == SizeType.ValuePercent100: size /= 100 size_type = SizeType.ValuePercent if size_type == SizeType.ValuePercent: size *= value_now size_type = SizeType.Value order_size = np.nan direction = Direction.Both abs_position_now = abs(position_now) if position_now > 0: # We're in a long position if is_short_entry: _check_size_type(size_type) if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly: # Decrease the position order_size = -size else: # Reverse the position if not np.isnan(size): if size_type == SizeType.Percent: order_size = -size else: order_size = -abs_position_now if size_type == SizeType.Value: order_size -= size / val_price_now else: order_size -= size size_type = SizeType.Amount elif is_long_exit: direction = Direction.LongOnly if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly: # Decrease the position _check_size_type(size_type) order_size = -size else: # Close the position order_size = -abs_position_now size_type = 
@register_jitted(cache=True)
def prepare_fs_records_nb(
    target_shape: tp.Shape,
    max_order_records: tp.Optional[int] = None,
    max_log_records: tp.Optional[int] = 0,
) -> tp.Tuple[tp.RecordArray2d, tp.RecordArray2d]:
    """Allocate the order and log record arrays for a from-signals simulation.

    Both arrays are two-dimensional and have `target_shape[1]` columns. The number of rows
    of each array is the respective `max_*_records` argument, or `target_shape[0]` if that
    argument is None. Order records use the data type `fs_order_dt` and log records use
    `log_dt`. The arrays are created with `np.empty`, so their contents are uninitialized.

    Returns the order record array and the log record array."""
    n_rows = target_shape[0]
    n_cols = target_shape[1]

    # Each allocation stays inside its own branch of the None check
    if max_order_records is not None:
        order_records = np.empty((max_order_records, n_cols), dtype=fs_order_dt)
    else:
        order_records = np.empty((n_rows, n_cols), dtype=fs_order_dt)

    if max_log_records is not None:
        log_records = np.empty((max_log_records, n_cols), dtype=log_dt)
    else:
        log_records = np.empty((n_rows, n_cols), dtype=log_dt)

    return order_records, log_records
save_state=None, save_value=None, save_returns=None, skip_empty=None, max_order_records=None, max_log_records=None, ), **portfolio_ch.merge_sim_outs_config, ) @register_jitted(cache=True, tags={"can_parallel"}) def from_basic_signals_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, open: tp.FlexArray2dLike = np.nan, high: tp.FlexArray2dLike = np.nan, low: tp.FlexArray2dLike = np.nan, close: tp.FlexArray2dLike = np.nan, init_cash: tp.FlexArray1dLike = 100.0, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, cash_deposits: tp.FlexArray2dLike = 0.0, cash_earnings: tp.FlexArray2dLike = 0.0, cash_dividends: tp.FlexArray2dLike = 0.0, long_entries: tp.FlexArray2dLike = False, long_exits: tp.FlexArray2dLike = False, short_entries: tp.FlexArray2dLike = False, short_exits: tp.FlexArray2dLike = False, size: tp.FlexArray2dLike = np.inf, price: tp.FlexArray2dLike = np.inf, size_type: tp.FlexArray2dLike = SizeType.Amount, fees: tp.FlexArray2dLike = 0.0, fixed_fees: tp.FlexArray2dLike = 0.0, slippage: tp.FlexArray2dLike = 0.0, min_size: tp.FlexArray2dLike = np.nan, max_size: tp.FlexArray2dLike = np.nan, size_granularity: tp.FlexArray2dLike = np.nan, leverage: tp.FlexArray2dLike = 1.0, leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy, reject_prob: tp.FlexArray2dLike = 0.0, price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore, allow_partial: tp.FlexArray2dLike = True, raise_reject: tp.FlexArray2dLike = False, log: tp.FlexArray2dLike = False, val_price: tp.FlexArray2dLike = np.inf, accumulate: tp.FlexArray2dLike = AccumulationMode.Disabled, upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore, upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce, from_ago: tp.FlexArray2dLike = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] 
= None, call_seq: tp.Optional[tp.Array2d] = None, auto_call_seq: bool = False, ffill_val_price: bool = True, update_value: bool = False, save_state: bool = False, save_value: bool = False, save_returns: bool = False, skip_empty: bool = True, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = 0, ) -> SimulationOutput: """Simulate given basic signals (no limit or stop orders). Iterates in the column-major order. Utilizes flexible broadcasting. !!! note Should be only grouped if cash sharing is enabled. """ check_group_lens_nb(group_lens, target_shape[1]) cash_sharing = is_grouped_nb(group_lens) open_ = to_2d_array_nb(np.asarray(open)) high_ = to_2d_array_nb(np.asarray(high)) low_ = to_2d_array_nb(np.asarray(low)) close_ = to_2d_array_nb(np.asarray(close)) init_cash_ = to_1d_array_nb(np.asarray(init_cash)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends)) long_entries_ = to_2d_array_nb(np.asarray(long_entries)) long_exits_ = to_2d_array_nb(np.asarray(long_exits)) short_entries_ = to_2d_array_nb(np.asarray(short_entries)) short_exits_ = to_2d_array_nb(np.asarray(short_exits)) size_ = to_2d_array_nb(np.asarray(size)) price_ = to_2d_array_nb(np.asarray(price)) size_type_ = to_2d_array_nb(np.asarray(size_type)) fees_ = to_2d_array_nb(np.asarray(fees)) fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees)) slippage_ = to_2d_array_nb(np.asarray(slippage)) min_size_ = to_2d_array_nb(np.asarray(min_size)) max_size_ = to_2d_array_nb(np.asarray(max_size)) size_granularity_ = to_2d_array_nb(np.asarray(size_granularity)) leverage_ = to_2d_array_nb(np.asarray(leverage)) leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode)) reject_prob_ = to_2d_array_nb(np.asarray(reject_prob)) price_area_vio_mode_ = 
to_2d_array_nb(np.asarray(price_area_vio_mode)) allow_partial_ = to_2d_array_nb(np.asarray(allow_partial)) raise_reject_ = to_2d_array_nb(np.asarray(raise_reject)) log_ = to_2d_array_nb(np.asarray(log)) val_price_ = to_2d_array_nb(np.asarray(val_price)) accumulate_ = to_2d_array_nb(np.asarray(accumulate)) upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict)) upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict)) upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict)) upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry)) from_ago_ = to_2d_array_nb(np.asarray(from_ago)) order_records, log_records = prepare_fs_records_nb( target_shape=target_shape, max_order_records=max_order_records, max_log_records=max_log_records, ) order_counts = np.full(target_shape[1], 0, dtype=int_) log_counts = np.full(target_shape[1], 0, dtype=int_) last_cash = prepare_last_cash_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, ) last_position = prepare_last_position_nb( target_shape=target_shape, init_position=init_position_, ) last_value = prepare_last_value_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, ) last_cash_deposits = np.full_like(last_cash, 0.0) last_val_price = np.full_like(last_position, np.nan) last_debt = np.full(target_shape[1], 0.0, dtype=float_) last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_) last_free_cash = last_cash.copy() prev_close_value = last_value.copy() last_return = np.full_like(last_cash, np.nan) track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1 if track_cash_deposits: cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_) else: cash_deposits_out = np.full((1, 1), 0.0, dtype=float_) track_cash_earnings = (cash_earnings_.size == 1 and 
cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1 track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1 track_cash_earnings = track_cash_earnings or track_cash_dividends if track_cash_earnings: cash_earnings_out = np.full(target_shape, 0.0, dtype=float_) else: cash_earnings_out = np.full((1, 1), 0.0, dtype=float_) if save_state: cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) position = np.full(target_shape, np.nan, dtype=float_) debt = np.full(target_shape, np.nan, dtype=float_) locked_cash = np.full(target_shape, np.nan, dtype=float_) free_cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: cash = np.full((0, 0), np.nan, dtype=float_) position = np.full((0, 0), np.nan, dtype=float_) debt = np.full((0, 0), np.nan, dtype=float_) locked_cash = np.full((0, 0), np.nan, dtype=float_) free_cash = np.full((0, 0), np.nan, dtype=float_) if save_value: value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: value = np.full((0, 0), np.nan, dtype=float_) if save_returns: returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: returns = np.full((0, 0), np.nan, dtype=float_) in_outputs = FSInOutputs( cash=cash, position=position, debt=debt, locked_cash=locked_cash, free_cash=free_cash, value=value, returns=returns, ) last_signal = np.empty(target_shape[1], dtype=int_) main_info = np.empty(target_shape[1], dtype=main_info_dt) temp_call_seq = np.empty(target_shape[1], dtype=int_) temp_sort_by = np.empty(target_shape[1], dtype=float_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(target_shape[0], len(group_lens)), sim_start=sim_start, sim_end=sim_end, ) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] group_len = to_col - from_col _sim_start = sim_start_[group] 
_sim_end = sim_end_[group] for i in range(_sim_start, _sim_end): # Add cash _cash_deposits = flex_select_nb(cash_deposits_, i, group) if _cash_deposits < 0: _cash_deposits = max(_cash_deposits, -last_cash[group]) last_cash[group] += _cash_deposits last_free_cash[group] += _cash_deposits last_cash_deposits[group] = _cash_deposits if track_cash_deposits: cash_deposits_out[i, group] += _cash_deposits for c in range(group_len): col = from_col + c # Update valuation price using current open _open = flex_select_nb(open_, i, col) if not np.isnan(_open) or not ffill_val_price: last_val_price[col] = _open # Resolve valuation price _val_price = flex_select_nb(val_price_, i, col) if np.isinf(_val_price): if _val_price > 0: _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: _price = np.nan else: _price = flex_select_nb(price_, _i, col) if np.isinf(_price): if _price > 0: _price = flex_select_nb(close_, i, col) else: _price = _open _val_price = _price else: _val_price = last_val_price[col] if not np.isnan(_val_price) or not ffill_val_price: last_val_price[col] = _val_price # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) # Get signals skip = skip_empty for c in range(group_len): col = from_col + c if flex_select_nb(log_, i, col): skip = False _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: is_long_entry = False is_long_exit = False is_short_entry = False is_short_exit = False else: is_long_entry = flex_select_nb(long_entries_, _i, col) is_long_exit = flex_select_nb(long_exits_, _i, col) is_short_entry = flex_select_nb(short_entries_, _i, col) is_short_exit = flex_select_nb(short_exits_, _i, col) # Pack signals into a single integer last_signal[col] = ( (is_long_entry << 4) | 
(is_long_exit << 3) | (is_short_entry << 2) | (is_short_exit << 1) ) if last_signal[col] > 0: skip = False if not skip: # Get size and value of each order for c in range(group_len): col = from_col + c # Set defaults main_info["bar_zone"][col] = -1 main_info["signal_idx"][col] = -1 main_info["creation_idx"][col] = -1 main_info["idx"][col] = i main_info["val_price"][col] = np.nan main_info["price"][col] = np.nan main_info["size"][col] = np.nan main_info["size_type"][col] = -1 main_info["direction"][col] = -1 main_info["type"][col] = OrderType.Market main_info["stop_type"][col] = -1 temp_sort_by[col] = 0.0 # Unpack a single integer into signals is_long_entry = (last_signal[col] >> 4) & 1 is_long_exit = (last_signal[col] >> 3) & 1 is_short_entry = (last_signal[col] >> 2) & 1 is_short_exit = (last_signal[col] >> 1) & 1 # Resolve the current bar _i = i - abs(flex_select_nb(from_ago_, i, col)) _open = flex_select_nb(open_, i, col) _high = flex_select_nb(high_, i, col) _low = flex_select_nb(low_, i, col) _close = flex_select_nb(close_, i, col) _high, _low = resolve_hl_nb( open=_open, high=_high, low=_low, close=_close, ) # Process user signal if _i >= 0: _accumulate = flex_select_nb(accumulate_, _i, col) if is_long_entry or is_short_entry: # Resolve any single-direction conflicts _upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col) is_long_entry, is_long_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_long_entry, is_exit=is_long_exit, direction=Direction.LongOnly, conflict_mode=_upon_long_conflict, ) _upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col) is_short_entry, is_short_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_short_entry, is_exit=is_short_exit, direction=Direction.ShortOnly, conflict_mode=_upon_short_conflict, ) # Resolve any multi-direction conflicts _upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col) is_long_entry, is_short_entry = 
resolve_dir_conflict_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_short_entry=is_short_entry, upon_dir_conflict=_upon_dir_conflict, ) # Resolve an opposite entry _upon_opposite_entry = flex_select_nb(upon_opposite_entry_, _i, col) ( is_long_entry, is_long_exit, is_short_entry, is_short_exit, _accumulate, ) = resolve_opposite_entry_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, upon_opposite_entry=_upon_opposite_entry, accumulate=_accumulate, ) # Resolve the price _price = flex_select_nb(price_, _i, col) # Convert both signals to size (direction-aware), size type, and direction _size, _size_type, _direction = signal_to_size_nb( position_now=last_position[col], val_price_now=last_val_price[col], value_now=last_value[group], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, size=flex_select_nb(size_, _i, col), size_type=flex_select_nb(size_type_, _i, col), accumulate=_accumulate, ) # Execute user signal if np.isinf(_price): if _price > 0: main_info["bar_zone"][col] = BarZone.Close else: main_info["bar_zone"][col] = BarZone.Open else: main_info["bar_zone"][col] = BarZone.Middle main_info["signal_idx"][col] = _i main_info["creation_idx"][col] = i main_info["idx"][col] = _i main_info["val_price"][col] = last_val_price[col] main_info["price"][col] = _price main_info["size"][col] = _size main_info["size_type"][col] = _size_type main_info["direction"][col] = _direction skip = skip_empty if skip: for col in range(from_col, to_col): if flex_select_nb(log_, i, col): skip = False break if not np.isnan(main_info["size"][col]): skip = False break if not skip: # Check bar zone and update valuation price bar_zone = -1 same_bar_zone = True same_timing = True for c in range(group_len): col = from_col + c if np.isnan(main_info["size"][col]): continue if bar_zone == -1: bar_zone = 
main_info["bar_zone"][col] if main_info["bar_zone"][col] != bar_zone: same_bar_zone = False same_timing = False if main_info["bar_zone"][col] == BarZone.Middle: same_timing = False _val_price = main_info["val_price"][col] if not np.isnan(_val_price) or not ffill_val_price: last_val_price[col] = _val_price if cash_sharing: # Dynamically sort by order value -> selling comes first to release funds early if call_seq is None: for c in range(group_len): temp_call_seq[c] = c call_seq_now = temp_call_seq[:group_len] else: call_seq_now = call_seq[i, from_col:to_col] if auto_call_seq: # Sort by order value if not same_timing: raise ValueError("Cannot sort orders by value if they are executed at different times") for c in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue # Approximate order value exec_state = ExecState( cash=last_cash[group] if cash_sharing else last_cash[col], position=last_position[col], debt=last_debt[col], locked_cash=last_locked_cash[col], free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col], val_price=last_val_price[col], value=last_value[group] if cash_sharing else last_value[col], ) temp_sort_by[c] = approx_order_value_nb( exec_state=exec_state, size=main_info["size"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], ) insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) else: if not same_bar_zone: # Sort by bar zone for c in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue temp_sort_by[c] = main_info["bar_zone"][col] insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) for k in range(group_len): if cash_sharing: c = call_seq_now[k] if c >= group_len: raise ValueError("Call index out of bounds of the group") else: c = k col = from_col + c if 
skip_empty and np.isnan(main_info["size"][col]): # shortcut continue # Get current values per column position_now = last_position[col] debt_now = last_debt[col] locked_cash_now = last_locked_cash[col] val_price_now = last_val_price[col] cash_now = last_cash[group] free_cash_now = last_free_cash[group] value_now = last_value[group] return_now = last_return[group] # Generate the next order _i = main_info["idx"][col] if main_info["type"][col] == OrderType.Limit: _slippage = 0.0 else: _slippage = float(flex_select_nb(slippage_, _i, col)) _min_size = flex_select_nb(min_size_, _i, col) _max_size = flex_select_nb(max_size_, _i, col) _size_type = flex_select_nb(size_type_, _i, col) if _size_type != main_info["size_type"][col]: if not np.isnan(_min_size): _min_size, _ = resolve_size_nb( size=_min_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) if not np.isnan(_max_size): _max_size, _ = resolve_size_nb( size=_max_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) order = order_nb( size=main_info["size"][col], price=main_info["price"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], fees=flex_select_nb(fees_, _i, col), fixed_fees=flex_select_nb(fixed_fees_, _i, col), slippage=_slippage, min_size=_min_size, max_size=_max_size, size_granularity=flex_select_nb(size_granularity_, _i, col), leverage=flex_select_nb(leverage_, _i, col), leverage_mode=flex_select_nb(leverage_mode_, _i, col), reject_prob=flex_select_nb(reject_prob_, _i, col), price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col), allow_partial=flex_select_nb(allow_partial_, _i, col), raise_reject=flex_select_nb(raise_reject_, _i, col), log=flex_select_nb(log_, _i, col), ) # Process the order price_area = PriceArea( open=flex_select_nb(open_, i, col), 
high=flex_select_nb(high_, i, col), low=flex_select_nb(low_, i, col), close=flex_select_nb(close_, i, col), ) exec_state = ExecState( cash=cash_now, position=position_now, debt=debt_now, locked_cash=locked_cash_now, free_cash=free_cash_now, val_price=val_price_now, value=value_now, ) order_result, new_exec_state = process_order_nb( group=group, col=col, i=i, exec_state=exec_state, order=order, price_area=price_area, update_value=update_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, ) # Append more order information if order_result.status == OrderStatus.Filled and order_counts[col] >= 1: order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col] order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col] order_records["type"][order_counts[col] - 1, col] = main_info["type"][col] order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col] # Update execution state cash_now = new_exec_state.cash position_now = new_exec_state.position debt_now = new_exec_state.debt locked_cash_now = new_exec_state.locked_cash free_cash_now = new_exec_state.free_cash val_price_now = new_exec_state.val_price value_now = new_exec_state.value # Now becomes last last_position[col] = position_now last_debt[col] = debt_now last_locked_cash[col] = locked_cash_now if not np.isnan(val_price_now) or not ffill_val_price: last_val_price[col] = val_price_now last_cash[group] = cash_now last_free_cash[group] = free_cash_now last_value[group] = value_now last_return[group] = return_now for col in range(from_col, to_col): # Update valuation price using current close _close = flex_select_nb(close_, i, col) if not np.isnan(_close) or not ffill_val_price: last_val_price[col] = _close _cash_earnings = flex_select_nb(cash_earnings_, i, col) _cash_dividends = flex_select_nb(cash_dividends_, i, col) _cash_earnings += _cash_dividends * last_position[col] 
last_cash[group] += _cash_earnings last_free_cash[group] += _cash_earnings if track_cash_earnings: cash_earnings_out[i, col] += _cash_earnings if save_state: position[i, col] = last_position[col] debt[i, col] = last_debt[col] locked_cash[i, col] = last_locked_cash[col] cash[i, group] = last_cash[group] free_cash[i, group] = last_free_cash[group] # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) prev_close_value[group] = last_value[group] if save_value: in_outputs.value[i, group] = last_value[group] if save_returns: in_outputs.returns[i, group] = last_return[group] sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start_, sim_end=sim_end_, allow_none=True, ) return prepare_sim_out_nb( order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, cash_deposits=cash_deposits_out, cash_earnings=cash_earnings_out, call_seq=call_seq, in_outputs=in_outputs, sim_start=sim_start_out, sim_end=sim_end_out, ) @register_chunkable( size=ch.ArraySizer(arg_query="group_lens", axis=0), arg_take_spec=dict( target_shape=base_ch.shape_gl_slicer, group_lens=ch.ArraySlicer(axis=0), index=None, freq=None, open=base_ch.flex_array_gl_slicer, high=base_ch.flex_array_gl_slicer, low=base_ch.flex_array_gl_slicer, close=base_ch.flex_array_gl_slicer, init_cash=base_ch.FlexArraySlicer(), init_position=base_ch.flex_1d_array_gl_slicer, init_price=base_ch.flex_1d_array_gl_slicer, cash_deposits=base_ch.FlexArraySlicer(axis=1), cash_earnings=base_ch.flex_array_gl_slicer, cash_dividends=base_ch.flex_array_gl_slicer, long_entries=base_ch.flex_array_gl_slicer, long_exits=base_ch.flex_array_gl_slicer, 
short_entries=base_ch.flex_array_gl_slicer, short_exits=base_ch.flex_array_gl_slicer, size=base_ch.flex_array_gl_slicer, price=base_ch.flex_array_gl_slicer, size_type=base_ch.flex_array_gl_slicer, fees=base_ch.flex_array_gl_slicer, fixed_fees=base_ch.flex_array_gl_slicer, slippage=base_ch.flex_array_gl_slicer, min_size=base_ch.flex_array_gl_slicer, max_size=base_ch.flex_array_gl_slicer, size_granularity=base_ch.flex_array_gl_slicer, leverage=base_ch.flex_array_gl_slicer, leverage_mode=base_ch.flex_array_gl_slicer, reject_prob=base_ch.flex_array_gl_slicer, price_area_vio_mode=base_ch.flex_array_gl_slicer, allow_partial=base_ch.flex_array_gl_slicer, raise_reject=base_ch.flex_array_gl_slicer, log=base_ch.flex_array_gl_slicer, val_price=base_ch.flex_array_gl_slicer, accumulate=base_ch.flex_array_gl_slicer, upon_long_conflict=base_ch.flex_array_gl_slicer, upon_short_conflict=base_ch.flex_array_gl_slicer, upon_dir_conflict=base_ch.flex_array_gl_slicer, upon_opposite_entry=base_ch.flex_array_gl_slicer, order_type=base_ch.flex_array_gl_slicer, limit_delta=base_ch.flex_array_gl_slicer, limit_tif=base_ch.flex_array_gl_slicer, limit_expiry=base_ch.flex_array_gl_slicer, limit_reverse=base_ch.flex_array_gl_slicer, limit_order_price=base_ch.flex_array_gl_slicer, upon_adj_limit_conflict=base_ch.flex_array_gl_slicer, upon_opp_limit_conflict=base_ch.flex_array_gl_slicer, use_stops=None, stop_ladder=None, sl_stop=base_ch.flex_array_gl_slicer, tsl_stop=base_ch.flex_array_gl_slicer, tsl_th=base_ch.flex_array_gl_slicer, tp_stop=base_ch.flex_array_gl_slicer, td_stop=base_ch.flex_array_gl_slicer, dt_stop=base_ch.flex_array_gl_slicer, stop_entry_price=base_ch.flex_array_gl_slicer, stop_exit_price=base_ch.flex_array_gl_slicer, stop_exit_type=base_ch.flex_array_gl_slicer, stop_order_type=base_ch.flex_array_gl_slicer, stop_limit_delta=base_ch.flex_array_gl_slicer, upon_stop_update=base_ch.flex_array_gl_slicer, upon_adj_stop_conflict=base_ch.flex_array_gl_slicer, 
upon_opp_stop_conflict=base_ch.flex_array_gl_slicer, delta_format=base_ch.flex_array_gl_slicer, time_delta_format=base_ch.flex_array_gl_slicer, from_ago=base_ch.flex_array_gl_slicer, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), call_seq=base_ch.array_gl_slicer, auto_call_seq=None, ffill_val_price=None, update_value=None, save_state=None, save_value=None, save_returns=None, skip_empty=None, max_order_records=None, max_log_records=None, ), **portfolio_ch.merge_sim_outs_config, ) @register_jitted(cache=True, tags={"can_parallel"}) def from_signals_nb( target_shape: tp.Shape, group_lens: tp.GroupLens, index: tp.Optional[tp.Array1d] = None, freq: tp.Optional[int] = None, open: tp.FlexArray2dLike = np.nan, high: tp.FlexArray2dLike = np.nan, low: tp.FlexArray2dLike = np.nan, close: tp.FlexArray2dLike = np.nan, init_cash: tp.FlexArray1dLike = 100.0, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, cash_deposits: tp.FlexArray2dLike = 0.0, cash_earnings: tp.FlexArray2dLike = 0.0, cash_dividends: tp.FlexArray2dLike = 0.0, long_entries: tp.FlexArray2dLike = False, long_exits: tp.FlexArray2dLike = False, short_entries: tp.FlexArray2dLike = False, short_exits: tp.FlexArray2dLike = False, size: tp.FlexArray2dLike = np.inf, price: tp.FlexArray2dLike = np.inf, size_type: tp.FlexArray2dLike = SizeType.Amount, fees: tp.FlexArray2dLike = 0.0, fixed_fees: tp.FlexArray2dLike = 0.0, slippage: tp.FlexArray2dLike = 0.0, min_size: tp.FlexArray2dLike = np.nan, max_size: tp.FlexArray2dLike = np.nan, size_granularity: tp.FlexArray2dLike = np.nan, leverage: tp.FlexArray2dLike = 1.0, leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy, reject_prob: tp.FlexArray2dLike = 0.0, price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore, allow_partial: tp.FlexArray2dLike = True, raise_reject: tp.FlexArray2dLike = False, log: tp.FlexArray2dLike = False, val_price: tp.FlexArray2dLike = np.inf, accumulate: tp.FlexArray2dLike = 
AccumulationMode.Disabled, upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore, upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce, order_type: tp.FlexArray2dLike = OrderType.Market, limit_delta: tp.FlexArray2dLike = np.nan, limit_tif: tp.FlexArray2dLike = -1, limit_expiry: tp.FlexArray2dLike = -1, limit_reverse: tp.FlexArray2dLike = False, limit_order_price: tp.FlexArray2dLike = LimitOrderPrice.Limit, upon_adj_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepIgnore, upon_opp_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.CancelExecute, use_stops: bool = True, stop_ladder: int = StopLadderMode.Disabled, sl_stop: tp.FlexArray2dLike = np.nan, tsl_stop: tp.FlexArray2dLike = np.nan, tsl_th: tp.FlexArray2dLike = np.nan, tp_stop: tp.FlexArray2dLike = np.nan, td_stop: tp.FlexArray2dLike = -1, dt_stop: tp.FlexArray2dLike = -1, stop_entry_price: tp.FlexArray2dLike = StopEntryPrice.Close, stop_exit_price: tp.FlexArray2dLike = StopExitPrice.Stop, stop_exit_type: tp.FlexArray2dLike = StopExitType.Close, stop_order_type: tp.FlexArray2dLike = OrderType.Market, stop_limit_delta: tp.FlexArray2dLike = np.nan, upon_stop_update: tp.FlexArray2dLike = StopUpdateMode.Keep, upon_adj_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute, upon_opp_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute, delta_format: tp.FlexArray2dLike = DeltaFormat.Percent, time_delta_format: tp.FlexArray2dLike = TimeDeltaFormat.Index, from_ago: tp.FlexArray2dLike = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, call_seq: tp.Optional[tp.Array2d] = None, auto_call_seq: bool = False, ffill_val_price: bool = True, update_value: bool = False, save_state: bool = False, save_value: bool = False, save_returns: bool = False, skip_empty: bool = True, 
max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = 0, ) -> SimulationOutput: """Simulate given signals. Iterates in the column-major order. Utilizes flexible broadcasting. !!! note Should be only grouped if cash sharing is enabled. """ check_group_lens_nb(group_lens, target_shape[1]) cash_sharing = is_grouped_nb(group_lens) open_ = to_2d_array_nb(np.asarray(open)) high_ = to_2d_array_nb(np.asarray(high)) low_ = to_2d_array_nb(np.asarray(low)) close_ = to_2d_array_nb(np.asarray(close)) init_cash_ = to_1d_array_nb(np.asarray(init_cash)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends)) long_entries_ = to_2d_array_nb(np.asarray(long_entries)) long_exits_ = to_2d_array_nb(np.asarray(long_exits)) short_entries_ = to_2d_array_nb(np.asarray(short_entries)) short_exits_ = to_2d_array_nb(np.asarray(short_exits)) size_ = to_2d_array_nb(np.asarray(size)) price_ = to_2d_array_nb(np.asarray(price)) size_type_ = to_2d_array_nb(np.asarray(size_type)) fees_ = to_2d_array_nb(np.asarray(fees)) fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees)) slippage_ = to_2d_array_nb(np.asarray(slippage)) min_size_ = to_2d_array_nb(np.asarray(min_size)) max_size_ = to_2d_array_nb(np.asarray(max_size)) size_granularity_ = to_2d_array_nb(np.asarray(size_granularity)) leverage_ = to_2d_array_nb(np.asarray(leverage)) leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode)) reject_prob_ = to_2d_array_nb(np.asarray(reject_prob)) price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode)) allow_partial_ = to_2d_array_nb(np.asarray(allow_partial)) raise_reject_ = to_2d_array_nb(np.asarray(raise_reject)) log_ = to_2d_array_nb(np.asarray(log)) val_price_ = to_2d_array_nb(np.asarray(val_price)) accumulate_ = 
to_2d_array_nb(np.asarray(accumulate)) upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict)) upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict)) upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict)) upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry)) order_type_ = to_2d_array_nb(np.asarray(order_type)) limit_delta_ = to_2d_array_nb(np.asarray(limit_delta)) limit_tif_ = to_2d_array_nb(np.asarray(limit_tif)) limit_expiry_ = to_2d_array_nb(np.asarray(limit_expiry)) limit_reverse_ = to_2d_array_nb(np.asarray(limit_reverse)) limit_order_price_ = to_2d_array_nb(np.asarray(limit_order_price)) upon_adj_limit_conflict_ = to_2d_array_nb(np.asarray(upon_adj_limit_conflict)) upon_opp_limit_conflict_ = to_2d_array_nb(np.asarray(upon_opp_limit_conflict)) sl_stop_ = to_2d_array_nb(np.asarray(sl_stop)) tsl_stop_ = to_2d_array_nb(np.asarray(tsl_stop)) tsl_th_ = to_2d_array_nb(np.asarray(tsl_th)) tp_stop_ = to_2d_array_nb(np.asarray(tp_stop)) td_stop_ = to_2d_array_nb(np.asarray(td_stop)) dt_stop_ = to_2d_array_nb(np.asarray(dt_stop)) stop_entry_price_ = to_2d_array_nb(np.asarray(stop_entry_price)) stop_exit_price_ = to_2d_array_nb(np.asarray(stop_exit_price)) stop_exit_type_ = to_2d_array_nb(np.asarray(stop_exit_type)) stop_order_type_ = to_2d_array_nb(np.asarray(stop_order_type)) stop_limit_delta_ = to_2d_array_nb(np.asarray(stop_limit_delta)) upon_stop_update_ = to_2d_array_nb(np.asarray(upon_stop_update)) upon_adj_stop_conflict_ = to_2d_array_nb(np.asarray(upon_adj_stop_conflict)) upon_opp_stop_conflict_ = to_2d_array_nb(np.asarray(upon_opp_stop_conflict)) delta_format_ = to_2d_array_nb(np.asarray(delta_format)) time_delta_format_ = to_2d_array_nb(np.asarray(time_delta_format)) from_ago_ = to_2d_array_nb(np.asarray(from_ago)) n_sl_steps = sl_stop_.shape[0] n_tsl_steps = tsl_stop_.shape[0] n_tp_steps = tp_stop_.shape[0] n_td_steps = td_stop_.shape[0] n_dt_steps = dt_stop_.shape[0] order_records, log_records = 
prepare_fs_records_nb( target_shape=target_shape, max_order_records=max_order_records, max_log_records=max_log_records, ) order_counts = np.full(target_shape[1], 0, dtype=int_) log_counts = np.full(target_shape[1], 0, dtype=int_) last_cash = prepare_last_cash_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, ) last_position = prepare_last_position_nb( target_shape=target_shape, init_position=init_position_, ) last_value = prepare_last_value_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, ) last_cash_deposits = np.full_like(last_cash, 0.0) last_val_price = np.full_like(last_position, np.nan) last_debt = np.full(target_shape[1], 0.0, dtype=float_) last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_) last_free_cash = last_cash.copy() prev_close_value = last_value.copy() last_return = np.full_like(last_cash, np.nan) track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1 if track_cash_deposits: cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_) else: cash_deposits_out = np.full((1, 1), 0.0, dtype=float_) track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1 track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1 track_cash_earnings = track_cash_earnings or track_cash_dividends if track_cash_earnings: cash_earnings_out = np.full(target_shape, 0.0, dtype=float_) else: cash_earnings_out = np.full((1, 1), 0.0, dtype=float_) if save_state: cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) position = np.full(target_shape, np.nan, dtype=float_) debt = np.full(target_shape, np.nan, dtype=float_) locked_cash = np.full(target_shape, np.nan, dtype=float_) free_cash = np.full((target_shape[0], 
len(group_lens)), np.nan, dtype=float_) else: cash = np.full((0, 0), np.nan, dtype=float_) position = np.full((0, 0), np.nan, dtype=float_) debt = np.full((0, 0), np.nan, dtype=float_) locked_cash = np.full((0, 0), np.nan, dtype=float_) free_cash = np.full((0, 0), np.nan, dtype=float_) if save_value: value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: value = np.full((0, 0), np.nan, dtype=float_) if save_returns: returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_) else: returns = np.full((0, 0), np.nan, dtype=float_) in_outputs = FSInOutputs( cash=cash, position=position, debt=debt, locked_cash=locked_cash, free_cash=free_cash, value=value, returns=returns, ) last_limit_info = np.empty(target_shape[1], dtype=limit_info_dt) last_limit_info["signal_idx"][:] = -1 last_limit_info["creation_idx"][:] = -1 last_limit_info["init_idx"][:] = -1 last_limit_info["init_price"][:] = np.nan last_limit_info["init_size"][:] = np.nan last_limit_info["init_size_type"][:] = -1 last_limit_info["init_direction"][:] = -1 last_limit_info["init_stop_type"][:] = -1 last_limit_info["delta"][:] = np.nan last_limit_info["delta_format"][:] = -1 last_limit_info["tif"][:] = -1 last_limit_info["expiry"][:] = -1 last_limit_info["time_delta_format"][:] = -1 last_limit_info["reverse"][:] = False last_limit_info["order_price"][:] = np.nan if use_stops: last_sl_info = np.empty(target_shape[1], dtype=sl_info_dt) last_sl_info["init_idx"][:] = -1 last_sl_info["init_price"][:] = np.nan last_sl_info["init_position"][:] = np.nan last_sl_info["stop"][:] = np.nan last_sl_info["exit_price"][:] = -1 last_sl_info["exit_size"][:] = np.nan last_sl_info["exit_size_type"][:] = -1 last_sl_info["exit_type"][:] = -1 last_sl_info["order_type"][:] = -1 last_sl_info["limit_delta"][:] = np.nan last_sl_info["delta_format"][:] = -1 last_sl_info["ladder"][:] = -1 last_sl_info["step"][:] = -1 last_sl_info["step_idx"][:] = -1 last_tsl_info = np.empty(target_shape[1], 
dtype=tsl_info_dt) last_tsl_info["init_idx"][:] = -1 last_tsl_info["init_price"][:] = np.nan last_tsl_info["init_position"][:] = np.nan last_tsl_info["peak_idx"][:] = -1 last_tsl_info["peak_price"][:] = np.nan last_tsl_info["stop"][:] = np.nan last_tsl_info["th"][:] = np.nan last_tsl_info["exit_price"][:] = -1 last_tsl_info["exit_size"][:] = np.nan last_tsl_info["exit_size_type"][:] = -1 last_tsl_info["exit_type"][:] = -1 last_tsl_info["order_type"][:] = -1 last_tsl_info["limit_delta"][:] = np.nan last_tsl_info["delta_format"][:] = -1 last_tsl_info["ladder"][:] = -1 last_tsl_info["step"][:] = -1 last_tsl_info["step_idx"][:] = -1 last_tp_info = np.empty(target_shape[1], dtype=tp_info_dt) last_tp_info["init_idx"][:] = -1 last_tp_info["init_price"][:] = np.nan last_tp_info["init_position"][:] = np.nan last_tp_info["stop"][:] = np.nan last_tp_info["exit_price"][:] = -1 last_tp_info["exit_size"][:] = np.nan last_tp_info["exit_size_type"][:] = -1 last_tp_info["exit_type"][:] = -1 last_tp_info["order_type"][:] = -1 last_tp_info["limit_delta"][:] = np.nan last_tp_info["delta_format"][:] = -1 last_tp_info["ladder"][:] = -1 last_tp_info["step"][:] = -1 last_tp_info["step_idx"][:] = -1 last_td_info = np.empty(target_shape[1], dtype=time_info_dt) last_td_info["init_idx"][:] = -1 last_td_info["init_position"][:] = np.nan last_td_info["stop"][:] = -1 last_td_info["exit_price"][:] = -1 last_td_info["exit_size"][:] = np.nan last_td_info["exit_size_type"][:] = -1 last_td_info["exit_type"][:] = -1 last_td_info["order_type"][:] = -1 last_td_info["limit_delta"][:] = np.nan last_td_info["delta_format"][:] = -1 last_td_info["time_delta_format"][:] = -1 last_td_info["ladder"][:] = -1 last_td_info["step"][:] = -1 last_td_info["step_idx"][:] = -1 last_dt_info = np.empty(target_shape[1], dtype=time_info_dt) last_dt_info["init_idx"][:] = -1 last_dt_info["init_position"][:] = np.nan last_dt_info["stop"][:] = -1 last_dt_info["exit_price"][:] = -1 last_dt_info["exit_size"][:] = np.nan 
last_dt_info["exit_size_type"][:] = -1 last_dt_info["exit_type"][:] = -1 last_dt_info["order_type"][:] = -1 last_dt_info["limit_delta"][:] = np.nan last_dt_info["delta_format"][:] = -1 last_dt_info["time_delta_format"][:] = -1 last_dt_info["ladder"][:] = -1 last_dt_info["step"][:] = -1 last_dt_info["step_idx"][:] = -1 else: last_sl_info = np.empty(0, dtype=sl_info_dt) last_tsl_info = np.empty(0, dtype=tsl_info_dt) last_tp_info = np.empty(0, dtype=tp_info_dt) last_td_info = np.empty(0, dtype=time_info_dt) last_dt_info = np.empty(0, dtype=time_info_dt) last_signal = np.empty(target_shape[1], dtype=int_) main_info = np.empty(target_shape[1], dtype=main_info_dt) temp_call_seq = np.empty(target_shape[1], dtype=int_) temp_sort_by = np.empty(target_shape[1], dtype=float_) group_end_idxs = np.cumsum(group_lens) group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(target_shape[0], len(group_lens)), sim_start=sim_start, sim_end=sim_end, ) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] group_len = to_col - from_col _sim_start = sim_start_[group] _sim_end = sim_end_[group] for i in range(_sim_start, _sim_end): # Add cash _cash_deposits = flex_select_nb(cash_deposits_, i, group) if _cash_deposits < 0: _cash_deposits = max(_cash_deposits, -last_cash[group]) last_cash[group] += _cash_deposits last_free_cash[group] += _cash_deposits last_cash_deposits[group] = _cash_deposits if track_cash_deposits: cash_deposits_out[i, group] += _cash_deposits for c in range(group_len): col = from_col + c # Update valuation price using current open _open = flex_select_nb(open_, i, col) if not np.isnan(_open) or not ffill_val_price: last_val_price[col] = _open # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value 
last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) # Get signals skip = skip_empty for c in range(group_len): col = from_col + c if flex_select_nb(log_, i, col): skip = False _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: is_long_entry = False is_long_exit = False is_short_entry = False is_short_exit = False else: is_long_entry = flex_select_nb(long_entries_, _i, col) is_long_exit = flex_select_nb(long_exits_, _i, col) is_short_entry = flex_select_nb(short_entries_, _i, col) is_short_exit = flex_select_nb(short_exits_, _i, col) limit_signal = is_limit_active_nb( init_idx=last_limit_info["init_idx"][col], init_price=last_limit_info["init_price"][col], ) if not use_stops: sl_stop_signal = False tsl_stop_signal = False tp_stop_signal = False td_stop_signal = False dt_stop_signal = False else: sl_stop_signal = is_stop_active_nb( init_idx=last_sl_info["init_idx"][col], stop=last_sl_info["stop"][col], ) tsl_stop_signal = is_stop_active_nb( init_idx=last_tsl_info["init_idx"][col], stop=last_tsl_info["stop"][col], ) tp_stop_signal = is_stop_active_nb( init_idx=last_tp_info["init_idx"][col], stop=last_tp_info["stop"][col], ) td_stop_signal = is_time_stop_active_nb( init_idx=last_td_info["init_idx"][col], stop=last_td_info["stop"][col], ) dt_stop_signal = is_time_stop_active_nb( init_idx=last_dt_info["init_idx"][col], stop=last_dt_info["stop"][col], ) # Pack signals into a single integer last_signal[col] = ( (is_long_entry << 10) | (is_long_exit << 9) | (is_short_entry << 8) | (is_short_exit << 7) | (limit_signal << 6) | (sl_stop_signal << 5) | (tsl_stop_signal << 4) | (tp_stop_signal << 3) | (td_stop_signal << 2) | (dt_stop_signal << 1) ) if last_signal[col] > 0: skip = False if not skip: # Get size and value of each order for c in range(group_len): col = from_col + c # Set defaults main_info["bar_zone"][col] = -1 main_info["signal_idx"][col] = -1 main_info["creation_idx"][col] = -1 
main_info["idx"][col] = i main_info["val_price"][col] = np.nan main_info["price"][col] = np.nan main_info["size"][col] = np.nan main_info["size_type"][col] = -1 main_info["direction"][col] = -1 main_info["type"][col] = -1 main_info["stop_type"][col] = -1 temp_sort_by[col] = 0.0 # Unpack a single integer into signals is_long_entry = (last_signal[col] >> 10) & 1 is_long_exit = (last_signal[col] >> 9) & 1 is_short_entry = (last_signal[col] >> 8) & 1 is_short_exit = (last_signal[col] >> 7) & 1 limit_signal = (last_signal[col] >> 6) & 1 sl_stop_signal = (last_signal[col] >> 5) & 1 tsl_stop_signal = (last_signal[col] >> 4) & 1 tp_stop_signal = (last_signal[col] >> 3) & 1 td_stop_signal = (last_signal[col] >> 2) & 1 dt_stop_signal = (last_signal[col] >> 1) & 1 any_user_signal = is_long_entry or is_long_exit or is_short_entry or is_short_exit any_limit_signal = limit_signal any_stop_signal = ( sl_stop_signal or tsl_stop_signal or tp_stop_signal or td_stop_signal or dt_stop_signal ) # Set initial info exec_limit_set = False exec_limit_set_on_open = False exec_limit_set_on_close = False exec_limit_signal_i = -1 exec_limit_creation_i = -1 exec_limit_init_i = -1 exec_limit_val_price = np.nan exec_limit_price = np.nan exec_limit_size = np.nan exec_limit_size_type = -1 exec_limit_direction = -1 exec_limit_stop_type = -1 exec_limit_bar_zone = -1 exec_stop_set = False exec_stop_set_on_open = False exec_stop_set_on_close = False exec_stop_init_i = -1 exec_stop_val_price = np.nan exec_stop_price = np.nan exec_stop_size = np.nan exec_stop_size_type = -1 exec_stop_direction = -1 exec_stop_type = -1 exec_stop_stop_type = -1 exec_stop_delta = np.nan exec_stop_delta_format = -1 exec_stop_make_limit = False exec_stop_bar_zone = -1 user_on_open = False user_on_close = False exec_user_set = False exec_user_val_price = np.nan exec_user_price = np.nan exec_user_size = np.nan exec_user_size_type = -1 exec_user_direction = -1 exec_user_type = -1 exec_user_stop_type = -1 exec_user_make_limit = 
False exec_user_bar_zone = -1 # Resolve the current bar _i = i - abs(flex_select_nb(from_ago_, i, col)) _open = flex_select_nb(open_, i, col) _high = flex_select_nb(high_, i, col) _low = flex_select_nb(low_, i, col) _close = flex_select_nb(close_, i, col) _high, _low = resolve_hl_nb( open=_open, high=_high, low=_low, close=_close, ) # Process the limit signal if any_limit_signal: # Check whether the limit price was hit _signal_i = last_limit_info["signal_idx"][col] _creation_i = last_limit_info["creation_idx"][col] _init_i = last_limit_info["init_idx"][col] _price = last_limit_info["init_price"][col] _size = last_limit_info["init_size"][col] _size_type = last_limit_info["init_size_type"][col] _direction = last_limit_info["init_direction"][col] _stop_type = last_limit_info["init_stop_type"][col] _delta = last_limit_info["delta"][col] _delta_format = last_limit_info["delta_format"][col] _tif = last_limit_info["tif"][col] _expiry = last_limit_info["expiry"][col] _time_delta_format = last_limit_info["time_delta_format"][col] _reverse = last_limit_info["reverse"][col] _order_price = last_limit_info["order_price"][col] limit_expired_on_open, limit_expired = check_limit_expired_nb( creation_idx=_creation_i, i=i, tif=_tif, expiry=_expiry, time_delta_format=_time_delta_format, index=index, freq=freq, ) limit_price, limit_hit_on_open, limit_hit = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, size=_size, direction=_direction, limit_delta=_delta, delta_format=_delta_format, limit_reverse=_reverse, can_use_ohlc=True, check_open=True, hard_limit=_order_price == LimitOrderPrice.HardLimit, ) # Resolve the price limit_price = resolve_limit_order_price_nb( limit_price=limit_price, close=_close, limit_order_price=_order_price, ) if limit_expired_on_open or (not limit_hit_on_open and limit_expired): # Expired limit signal any_limit_signal = False last_limit_info["signal_idx"][col] = -1 last_limit_info["creation_idx"][col] = -1 
last_limit_info["init_idx"][col] = -1 last_limit_info["init_price"][col] = np.nan last_limit_info["init_size"][col] = np.nan last_limit_info["init_size_type"][col] = -1 last_limit_info["init_direction"][col] = -1 last_limit_info["delta"][col] = np.nan last_limit_info["delta_format"][col] = -1 last_limit_info["tif"][col] = -1 last_limit_info["expiry"][col] = -1 last_limit_info["time_delta_format"][col] = -1 last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = np.nan else: # Save info if limit_hit: # Executable limit signal exec_limit_set = True exec_limit_set_on_open = limit_hit_on_open exec_limit_set_on_close = _order_price == LimitOrderPrice.Close exec_limit_signal_i = _signal_i exec_limit_creation_i = _creation_i exec_limit_init_i = _init_i if np.isinf(limit_price) and limit_price > 0: exec_limit_val_price = _close elif np.isinf(limit_price) and limit_price < 0: exec_limit_val_price = _open else: exec_limit_val_price = limit_price exec_limit_price = limit_price exec_limit_size = _size exec_limit_size_type = _size_type exec_limit_direction = _direction exec_limit_stop_type = _stop_type # Process the stop signal if any_stop_signal: # Check SL sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = np.nan, False, False if sl_stop_signal: # Check against high and low sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_sl_info["init_price"][col], stop=last_sl_info["stop"][col], delta_format=last_sl_info["delta_format"][col], hit_below=True, hard_stop=last_sl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TSL and TTP tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = np.nan, False, False if tsl_stop_signal: # Update peak price using open if last_position[col] > 0: if _open > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( 
last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _open else: if _open < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _open # Check threshold against previous bars and open if np.isnan(last_tsl_info["th"][col]): th_hit = True else: th_hit = check_tsl_th_hit_nb( is_position_long=last_position[col] > 0, init_price=last_tsl_info["init_price"][col], peak_price=last_tsl_info["peak_price"][col], threshold=last_tsl_info["th"][col], delta_format=last_tsl_info["delta_format"][col], ) if th_hit: tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tsl_info["peak_price"][col], stop=last_tsl_info["stop"][col], delta_format=last_tsl_info["delta_format"][col], hit_below=True, hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Update peak price using full bar if last_position[col] > 0: if _high > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _high else: if _low < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _low if not tsl_stop_hit: # Check threshold against full bar if not th_hit: if np.isnan(last_tsl_info["th"][col]): th_hit = True else: th_hit = check_tsl_th_hit_nb( 
is_position_long=last_position[col] > 0, init_price=last_tsl_info["init_price"][col], peak_price=last_tsl_info["peak_price"][col], threshold=last_tsl_info["th"][col], delta_format=last_tsl_info["delta_format"][col], ) if th_hit: # Check threshold against close tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tsl_info["peak_price"][col], stop=last_tsl_info["stop"][col], delta_format=last_tsl_info["delta_format"][col], hit_below=True, can_use_ohlc=False, hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TP tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = np.nan, False, False if tp_stop_signal: tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tp_info["init_price"][col], stop=last_tp_info["stop"][col], delta_format=last_tp_info["delta_format"][col], hit_below=False, hard_stop=last_tp_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TD td_stop_price, td_stop_hit_on_open, td_stop_hit = np.nan, False, False if td_stop_signal: td_stop_hit_on_open, td_stop_hit = check_td_stop_hit_nb( init_idx=last_td_info["init_idx"][col], i=i, stop=last_td_info["stop"][col], time_delta_format=last_td_info["time_delta_format"][col], index=index, freq=freq, ) if np.isnan(_open): td_stop_hit_on_open = False if td_stop_hit_on_open: td_stop_price = _open else: td_stop_price = _close # Check DT dt_stop_price, dt_stop_hit_on_open, dt_stop_hit = np.nan, False, False if dt_stop_signal: dt_stop_hit_on_open, dt_stop_hit = check_dt_stop_hit_nb( i=i, stop=last_dt_info["stop"][col], time_delta_format=last_dt_info["time_delta_format"][col], index=index, freq=freq, ) if np.isnan(_open): dt_stop_hit_on_open = False if dt_stop_hit_on_open: dt_stop_price = _open else: dt_stop_price = _close # Resolve the stop signal 
sl_hit = False tsl_hit = False tp_hit = False td_hit = False dt_hit = False if sl_stop_hit_on_open: sl_hit = True elif tsl_stop_hit_on_open: tsl_hit = True elif tp_stop_hit_on_open: tp_hit = True elif td_stop_hit_on_open: td_hit = True elif dt_stop_hit_on_open: dt_hit = True elif sl_stop_hit: sl_hit = True elif tsl_stop_hit: tsl_hit = True elif tp_stop_hit: tp_hit = True elif td_stop_hit: td_hit = True elif dt_stop_hit: dt_hit = True if sl_hit: stop_price, stop_hit_on_open, stop_hit = sl_stop_price, sl_stop_hit_on_open, sl_stop_hit _stop_type = StopType.SL _init_i = last_sl_info["init_idx"][col] _stop_exit_price = last_sl_info["exit_price"][col] _stop_exit_size = last_sl_info["exit_size"][col] _stop_exit_size_type = last_sl_info["exit_size_type"][col] _stop_exit_type = last_sl_info["exit_type"][col] _stop_order_type = last_sl_info["order_type"][col] _limit_delta = last_sl_info["limit_delta"][col] _delta_format = last_sl_info["delta_format"][col] _ladder = last_sl_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_sl_info["step"][col] if step < n_sl_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=sl_stop_, step=step, col=col, init_price=last_sl_info["init_price"][col], init_position=last_sl_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_sl_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = SizeType.Amount elif tsl_hit: stop_price, stop_hit_on_open, stop_hit = ( tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit, ) if np.isnan(last_tsl_info["th"][col]): _stop_type = StopType.TSL else: _stop_type = StopType.TTP _init_i = last_tsl_info["init_idx"][col] _stop_exit_price = last_tsl_info["exit_price"][col] _stop_exit_size = last_tsl_info["exit_size"][col] _stop_exit_size_type = last_tsl_info["exit_size_type"][col] _stop_exit_type = last_tsl_info["exit_type"][col] _stop_order_type = last_tsl_info["order_type"][col] 
_limit_delta = last_tsl_info["limit_delta"][col] _delta_format = last_tsl_info["delta_format"][col] _ladder = last_tsl_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_tsl_info["step"][col] if step < n_tsl_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=tsl_stop_, step=step, col=col, init_price=last_tsl_info["init_price"][col], init_position=last_tsl_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_tsl_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = SizeType.Amount elif tp_hit: stop_price, stop_hit_on_open, stop_hit = tp_stop_price, tp_stop_hit_on_open, tp_stop_hit _stop_type = StopType.TP _init_i = last_tp_info["init_idx"][col] _stop_exit_price = last_tp_info["exit_price"][col] _stop_exit_size = last_tp_info["exit_size"][col] _stop_exit_size_type = last_tp_info["exit_size_type"][col] _stop_exit_type = last_tp_info["exit_type"][col] _stop_order_type = last_tp_info["order_type"][col] _limit_delta = last_tp_info["limit_delta"][col] _delta_format = last_tp_info["delta_format"][col] _ladder = last_tp_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_tp_info["step"][col] if step < n_tp_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=tp_stop_, step=step, col=col, init_price=last_tp_info["init_price"][col], init_position=last_tp_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_tp_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = SizeType.Amount elif td_hit: stop_price, stop_hit_on_open, stop_hit = td_stop_price, td_stop_hit_on_open, td_stop_hit _stop_type = StopType.TD _init_i = last_td_info["init_idx"][col] _stop_exit_price = last_td_info["exit_price"][col] _stop_exit_size = last_td_info["exit_size"][col] _stop_exit_size_type = last_td_info["exit_size_type"][col] 
_stop_exit_type = last_td_info["exit_type"][col] _stop_order_type = last_td_info["order_type"][col] _limit_delta = last_td_info["limit_delta"][col] _delta_format = last_td_info["delta_format"][col] _ladder = last_td_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_td_info["step"][col] if step < n_td_steps: _stop_exit_size = get_time_stop_ladder_exit_size_nb( stop_=td_stop_, step=step, col=col, init_idx=last_td_info["init_idx"][col], init_position=last_td_info["init_position"][col], position_now=last_position[col], ladder=_ladder, time_delta_format=last_td_info["time_delta_format"][col], index=index, ) _stop_exit_size_type = SizeType.Amount elif dt_hit: stop_price, stop_hit_on_open, stop_hit = dt_stop_price, dt_stop_hit_on_open, dt_stop_hit _stop_type = StopType.DT _init_i = last_dt_info["init_idx"][col] _stop_exit_price = last_dt_info["exit_price"][col] _stop_exit_size = last_dt_info["exit_size"][col] _stop_exit_size_type = last_dt_info["exit_size_type"][col] _stop_exit_type = last_dt_info["exit_type"][col] _stop_order_type = last_dt_info["order_type"][col] _limit_delta = last_dt_info["limit_delta"][col] _delta_format = last_dt_info["delta_format"][col] _ladder = last_dt_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_dt_info["step"][col] if step < n_dt_steps: _stop_exit_size = get_time_stop_ladder_exit_size_nb( stop_=dt_stop_, step=step, col=col, init_idx=last_dt_info["init_idx"][col], init_position=last_dt_info["init_position"][col], position_now=last_position[col], ladder=_ladder, time_delta_format=last_dt_info["time_delta_format"][col], index=index, ) _stop_exit_size_type = SizeType.Amount else: stop_price, stop_hit_on_open, stop_hit = np.nan, False, False if stop_hit: # Stop price was hit # Resolve the final stop signal _accumulate = flex_select_nb(accumulate_, i, col) _size = flex_select_nb(size_, i, 
col) _size_type = flex_select_nb(size_type_, i, col) if not np.isnan(_stop_exit_size): _accumulate = True if _stop_exit_type == StopExitType.Close: _stop_exit_type = StopExitType.CloseReduce _size = _stop_exit_size if _stop_exit_size_type != -1: _size_type = _stop_exit_size_type ( stop_is_long_entry, stop_is_long_exit, stop_is_short_entry, stop_is_short_exit, _accumulate, ) = generate_stop_signal_nb( position_now=last_position[col], stop_exit_type=_stop_exit_type, accumulate=_accumulate, ) # Resolve the price _price = resolve_stop_exit_price_nb( stop_price=stop_price, close=_close, stop_exit_price=_stop_exit_price, ) # Convert both signals to size (direction-aware), size type, and direction _size, _size_type, _direction = signal_to_size_nb( position_now=last_position[col], val_price_now=_price, value_now=last_value[group], is_long_entry=stop_is_long_entry, is_long_exit=stop_is_long_exit, is_short_entry=stop_is_short_entry, is_short_exit=stop_is_short_exit, size=_size, size_type=_size_type, accumulate=_accumulate, ) if not np.isnan(_size): # Executable stop signal can_execute = True if _stop_order_type == OrderType.Limit: # Use close to check whether the limit price was hit if _stop_exit_price == StopExitPrice.Close: # Cannot place a limit order at the close price and execute right away can_execute = False if can_execute: limit_price, _, can_execute = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, size=_size, direction=_direction, limit_delta=_limit_delta, delta_format=_delta_format, limit_reverse=False, can_use_ohlc=stop_hit_on_open, check_open=False, hard_limit=False, ) if can_execute: _price = limit_price # Save info exec_stop_set = True exec_stop_set_on_open = stop_hit_on_open exec_stop_set_on_close = _stop_exit_price == StopExitPrice.Close exec_stop_init_i = _init_i if np.isinf(_price) and _price > 0: exec_stop_val_price = _close elif np.isinf(_price) and _price < 0: exec_stop_val_price = _open else: exec_stop_val_price = 
_price exec_stop_price = _price exec_stop_size = _size exec_stop_size_type = _size_type exec_stop_direction = _direction exec_stop_type = _stop_order_type exec_stop_stop_type = _stop_type exec_stop_delta = _limit_delta exec_stop_delta_format = _delta_format exec_stop_make_limit = not can_execute # Process user signal if any_user_signal: if _i < 0: _price = np.nan _size = np.nan _size_type = -1 _direction = -1 else: _accumulate = flex_select_nb(accumulate_, _i, col) if is_long_entry or is_short_entry: # Resolve any single-direction conflicts _upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col) is_long_entry, is_long_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_long_entry, is_exit=is_long_exit, direction=Direction.LongOnly, conflict_mode=_upon_long_conflict, ) _upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col) is_short_entry, is_short_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_short_entry, is_exit=is_short_exit, direction=Direction.ShortOnly, conflict_mode=_upon_short_conflict, ) # Resolve any multi-direction conflicts _upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col) is_long_entry, is_short_entry = resolve_dir_conflict_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_short_entry=is_short_entry, upon_dir_conflict=_upon_dir_conflict, ) # Resolve an opposite entry _upon_opposite_entry = flex_select_nb(upon_opposite_entry_, _i, col) ( is_long_entry, is_long_exit, is_short_entry, is_short_exit, _accumulate, ) = resolve_opposite_entry_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, upon_opposite_entry=_upon_opposite_entry, accumulate=_accumulate, ) # Resolve the price _price = flex_select_nb(price_, _i, col) # Convert both signals to size (direction-aware), size type, and direction _val_price = flex_select_nb(val_price_, i, col) if 
np.isinf(_val_price) and _val_price > 0: if np.isinf(_price) and _price > 0: _val_price = _close elif np.isinf(_price) and _price < 0: _val_price = _open else: _val_price = _price elif np.isnan(_val_price) or (np.isinf(_val_price) and _val_price < 0): _val_price = last_val_price[col] _size, _size_type, _direction = signal_to_size_nb( position_now=last_position[col], val_price_now=_val_price, value_now=last_value[group], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, size=flex_select_nb(size_, _i, col), size_type=flex_select_nb(size_type_, _i, col), accumulate=_accumulate, ) if np.isinf(_price): if _price > 0: user_on_close = True else: user_on_open = True if not np.isnan(_size): # Executable user signal can_execute = True _order_type = flex_select_nb(order_type_, _i, col) if _order_type == OrderType.Limit: # Use close to check whether the limit price was hit can_use_ohlc = False if np.isinf(_price): if _price > 0: # Cannot place a limit order at the close price and execute right away _price = _close can_execute = False else: can_use_ohlc = True _price = _open if can_execute: _limit_delta = flex_select_nb(limit_delta_, _i, col) _delta_format = flex_select_nb(delta_format_, _i, col) _limit_reverse = flex_select_nb(limit_reverse_, _i, col) limit_price, _, can_execute = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, size=_size, direction=_direction, limit_delta=_limit_delta, delta_format=_delta_format, limit_reverse=_limit_reverse, can_use_ohlc=can_use_ohlc, check_open=False, hard_limit=False, ) if can_execute: _price = limit_price # Save info exec_user_set = True exec_user_val_price = _val_price exec_user_price = _price exec_user_size = _size exec_user_size_type = _size_type exec_user_direction = _direction exec_user_type = _order_type exec_user_stop_type = -1 exec_user_make_limit = not can_execute if ( exec_limit_set or exec_stop_set or exec_user_set or 
((any_limit_signal or any_stop_signal) and any_user_signal) ): # Choose the main executable signal # Priority: limit -> stop -> user # Check whether the main signal comes on open keep_limit = True keep_stop = True execute_limit = False execute_stop = False execute_user = False if exec_limit_set_on_open: keep_limit = False keep_stop = False execute_limit = True if exec_limit_set_on_close: exec_limit_bar_zone = BarZone.Close else: exec_limit_bar_zone = BarZone.Open elif exec_stop_set_on_open: keep_limit = False keep_stop = _ladder execute_stop = True if exec_stop_set_on_close: exec_stop_bar_zone = BarZone.Close else: exec_stop_bar_zone = BarZone.Open elif any_user_signal and user_on_open: execute_user = True if any_limit_signal and (execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and (execute_user or not exec_user_set): keep_stop, execute_user = resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Open if not execute_limit and not execute_stop and not execute_user: # Check whether the main signal comes in the middle of the bar if exec_limit_set and not exec_limit_set_on_open and keep_limit: keep_limit = False keep_stop = False execute_limit = True exec_limit_bar_zone = BarZone.Middle elif ( exec_stop_set and not exec_stop_set_on_open and not exec_stop_set_on_close and keep_stop ): 
keep_limit = False keep_stop = _ladder execute_stop = True exec_stop_bar_zone = BarZone.Middle elif any_user_signal and not user_on_open and not user_on_close: execute_user = True if any_limit_signal and keep_limit and (execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and keep_stop and (execute_user or not exec_user_set): keep_stop, execute_user = resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Middle if not execute_limit and not execute_stop and not execute_user: # Check whether the main signal comes on close if exec_stop_set_on_close and keep_stop: keep_limit = False keep_stop = _ladder execute_stop = True exec_stop_bar_zone = BarZone.Close elif any_user_signal and user_on_close: execute_user = True if any_limit_signal and keep_limit and (execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and keep_stop and (execute_user or not exec_user_set): keep_stop, execute_user = 
resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Close # Process the limit signal if execute_limit: # Execute the signal main_info["bar_zone"][col] = exec_limit_bar_zone main_info["signal_idx"][col] = exec_limit_signal_i main_info["creation_idx"][col] = exec_limit_creation_i main_info["idx"][col] = exec_limit_init_i main_info["val_price"][col] = exec_limit_val_price main_info["price"][col] = exec_limit_price main_info["size"][col] = exec_limit_size main_info["size_type"][col] = exec_limit_size_type main_info["direction"][col] = exec_limit_direction main_info["type"][col] = OrderType.Limit main_info["stop_type"][col] = exec_limit_stop_type if execute_limit or (any_limit_signal and not keep_limit): # Clear the pending info any_limit_signal = False last_limit_info["signal_idx"][col] = -1 last_limit_info["creation_idx"][col] = -1 last_limit_info["init_idx"][col] = -1 last_limit_info["init_price"][col] = np.nan last_limit_info["init_size"][col] = np.nan last_limit_info["init_size_type"][col] = -1 last_limit_info["init_direction"][col] = -1 last_limit_info["init_stop_type"][col] = -1 last_limit_info["delta"][col] = np.nan last_limit_info["delta_format"][col] = -1 last_limit_info["tif"][col] = -1 last_limit_info["expiry"][col] = -1 last_limit_info["time_delta_format"][col] = -1 last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = np.nan # Process the stop signal if execute_stop: # Execute the signal if exec_stop_make_limit: if any_limit_signal: raise ValueError("Only one active limit signal is allowed at a time") _limit_tif = flex_select_nb(limit_tif_, i, col) _limit_expiry = flex_select_nb(limit_expiry_, i, col) _time_delta_format = 
flex_select_nb(time_delta_format_, i, col) _limit_order_price = flex_select_nb(limit_order_price_, i, col) last_limit_info["signal_idx"][col] = exec_stop_init_i last_limit_info["creation_idx"][col] = i last_limit_info["init_idx"][col] = i last_limit_info["init_price"][col] = exec_stop_price last_limit_info["init_size"][col] = exec_stop_size last_limit_info["init_size_type"][col] = exec_stop_size_type last_limit_info["init_direction"][col] = exec_stop_direction last_limit_info["init_stop_type"][col] = exec_stop_stop_type last_limit_info["delta"][col] = exec_stop_delta last_limit_info["delta_format"][col] = exec_stop_delta_format last_limit_info["tif"][col] = _limit_tif last_limit_info["expiry"][col] = _limit_expiry last_limit_info["time_delta_format"][col] = _time_delta_format last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = _limit_order_price else: main_info["bar_zone"][col] = exec_stop_bar_zone main_info["signal_idx"][col] = exec_stop_init_i main_info["creation_idx"][col] = i main_info["idx"][col] = i main_info["val_price"][col] = exec_stop_val_price main_info["price"][col] = exec_stop_price main_info["size"][col] = exec_stop_size main_info["size_type"][col] = exec_stop_size_type main_info["direction"][col] = exec_stop_direction main_info["type"][col] = exec_stop_type main_info["stop_type"][col] = exec_stop_stop_type if any_stop_signal and not keep_stop: # Clear the pending info any_stop_signal = False last_sl_info["init_idx"][col] = -1 last_sl_info["init_price"][col] = np.nan last_sl_info["init_position"][col] = np.nan last_sl_info["stop"][col] = np.nan last_sl_info["exit_price"][col] = -1 last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = -1 last_sl_info["order_type"][col] = -1 last_sl_info["limit_delta"][col] = np.nan last_sl_info["delta_format"][col] = -1 last_sl_info["ladder"][col] = -1 last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 
last_tsl_info["init_idx"][col] = -1 last_tsl_info["init_price"][col] = np.nan last_tsl_info["init_position"][col] = np.nan last_tsl_info["peak_idx"][col] = -1 last_tsl_info["peak_price"][col] = np.nan last_tsl_info["stop"][col] = np.nan last_tsl_info["th"][col] = np.nan last_tsl_info["exit_price"][col] = -1 last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = -1 last_tsl_info["order_type"][col] = -1 last_tsl_info["limit_delta"][col] = np.nan last_tsl_info["delta_format"][col] = -1 last_tsl_info["ladder"][col] = -1 last_tsl_info["step"][col] = -1 last_tsl_info["step_idx"][col] = -1 last_tp_info["init_idx"][col] = -1 last_tp_info["init_price"][col] = np.nan last_tp_info["init_position"][col] = np.nan last_tp_info["stop"][col] = np.nan last_tp_info["exit_price"][col] = -1 last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = -1 last_tp_info["order_type"][col] = -1 last_tp_info["limit_delta"][col] = np.nan last_tp_info["delta_format"][col] = -1 last_tp_info["ladder"][col] = -1 last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 last_td_info["init_idx"][col] = -1 last_td_info["init_position"][col] = np.nan last_td_info["stop"][col] = -1 last_td_info["exit_price"][col] = -1 last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = -1 last_td_info["order_type"][col] = -1 last_td_info["limit_delta"][col] = np.nan last_td_info["delta_format"][col] = -1 last_td_info["time_delta_format"][col] = -1 last_td_info["ladder"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 last_dt_info["init_idx"][col] = -1 last_dt_info["init_position"][col] = np.nan last_dt_info["stop"][col] = -1 last_dt_info["exit_price"][col] = -1 last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = -1 last_dt_info["order_type"][col] 
= -1 last_dt_info["limit_delta"][col] = np.nan last_dt_info["delta_format"][col] = -1 last_dt_info["time_delta_format"][col] = -1 last_dt_info["ladder"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 # Process the user signal if execute_user: # Execute the signal if _i >= 0: if exec_user_make_limit: if any_limit_signal: raise ValueError("Only one active limit signal is allowed at a time") _limit_delta = flex_select_nb(limit_delta_, _i, col) _delta_format = flex_select_nb(delta_format_, _i, col) _limit_tif = flex_select_nb(limit_tif_, _i, col) _limit_expiry = flex_select_nb(limit_expiry_, _i, col) _time_delta_format = flex_select_nb(time_delta_format_, _i, col) _limit_reverse = flex_select_nb(limit_reverse_, _i, col) _limit_order_price = flex_select_nb(limit_order_price_, _i, col) last_limit_info["signal_idx"][col] = _i last_limit_info["creation_idx"][col] = i last_limit_info["init_idx"][col] = _i last_limit_info["init_price"][col] = exec_user_price last_limit_info["init_size"][col] = exec_user_size last_limit_info["init_size_type"][col] = exec_user_size_type last_limit_info["init_direction"][col] = exec_user_direction last_limit_info["init_stop_type"][col] = -1 last_limit_info["delta"][col] = _limit_delta last_limit_info["delta_format"][col] = _delta_format last_limit_info["tif"][col] = _limit_tif last_limit_info["expiry"][col] = _limit_expiry last_limit_info["time_delta_format"][col] = _time_delta_format last_limit_info["reverse"][col] = _limit_reverse last_limit_info["order_price"][col] = _limit_order_price else: main_info["bar_zone"][col] = exec_user_bar_zone main_info["signal_idx"][col] = _i main_info["creation_idx"][col] = i main_info["idx"][col] = _i main_info["val_price"][col] = exec_user_val_price main_info["price"][col] = exec_user_price main_info["size"][col] = exec_user_size main_info["size_type"][col] = exec_user_size_type main_info["direction"][col] = exec_user_direction main_info["type"][col] = exec_user_type 
main_info["stop_type"][col] = exec_user_stop_type skip = skip_empty if skip: for col in range(from_col, to_col): if flex_select_nb(log_, i, col): skip = False break if not np.isnan(main_info["size"][col]): skip = False break if not skip: # Check bar zone and update valuation price bar_zone = -1 same_bar_zone = True same_timing = True for c in range(group_len): col = from_col + c if np.isnan(main_info["size"][col]): continue if bar_zone == -1: bar_zone = main_info["bar_zone"][col] if main_info["bar_zone"][col] != bar_zone: same_bar_zone = False same_timing = False if main_info["bar_zone"][col] == BarZone.Middle: same_timing = False _val_price = main_info["val_price"][col] if not np.isnan(_val_price) or not ffill_val_price: last_val_price[col] = _val_price if cash_sharing: # Dynamically sort by order value -> selling comes first to release funds early if call_seq is None: for c in range(group_len): temp_call_seq[c] = c call_seq_now = temp_call_seq[:group_len] else: call_seq_now = call_seq[i, from_col:to_col] if auto_call_seq: # Sort by order value if not same_timing: raise ValueError("Cannot sort orders by value if they are executed at different times") for c in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue # Approximate order value exec_state = ExecState( cash=last_cash[group] if cash_sharing else last_cash[col], position=last_position[col], debt=last_debt[col], locked_cash=last_locked_cash[col], free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col], val_price=last_val_price[col], value=last_value[group] if cash_sharing else last_value[col], ) temp_sort_by[c] = approx_order_value_nb( exec_state=exec_state, size=main_info["size"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], ) insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) else: if not same_bar_zone: # Sort by bar zone for c 
in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue temp_sort_by[c] = main_info["bar_zone"][col] insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) for k in range(group_len): if cash_sharing: c = call_seq_now[k] if c >= group_len: raise ValueError("Call index out of bounds of the group") else: c = k col = from_col + c if skip_empty and np.isnan(main_info["size"][col]): # shortcut continue # Get current values per column position_now = last_position[col] debt_now = last_debt[col] locked_cash_now = last_locked_cash[col] val_price_now = last_val_price[col] cash_now = last_cash[group] free_cash_now = last_free_cash[group] value_now = last_value[group] return_now = last_return[group] # Generate the next order _i = main_info["idx"][col] if main_info["type"][col] == OrderType.Limit: _slippage = 0.0 else: _slippage = float(flex_select_nb(slippage_, _i, col)) _min_size = flex_select_nb(min_size_, _i, col) _max_size = flex_select_nb(max_size_, _i, col) _size_type = flex_select_nb(size_type_, _i, col) if _size_type != main_info["size_type"][col]: if not np.isnan(_min_size): _min_size, _ = resolve_size_nb( size=_min_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) if not np.isnan(_max_size): _max_size, _ = resolve_size_nb( size=_max_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) order = order_nb( size=main_info["size"][col], price=main_info["price"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], fees=flex_select_nb(fees_, _i, col), fixed_fees=flex_select_nb(fixed_fees_, _i, col), slippage=_slippage, min_size=_min_size, max_size=_max_size, 
size_granularity=flex_select_nb(size_granularity_, _i, col), leverage=flex_select_nb(leverage_, _i, col), leverage_mode=flex_select_nb(leverage_mode_, _i, col), reject_prob=flex_select_nb(reject_prob_, _i, col), price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col), allow_partial=flex_select_nb(allow_partial_, _i, col), raise_reject=flex_select_nb(raise_reject_, _i, col), log=flex_select_nb(log_, _i, col), ) # Process the order price_area = PriceArea( open=flex_select_nb(open_, i, col), high=flex_select_nb(high_, i, col), low=flex_select_nb(low_, i, col), close=flex_select_nb(close_, i, col), ) exec_state = ExecState( cash=cash_now, position=position_now, debt=debt_now, locked_cash=locked_cash_now, free_cash=free_cash_now, val_price=val_price_now, value=value_now, ) order_result, new_exec_state = process_order_nb( group=group, col=col, i=i, exec_state=exec_state, order=order, price_area=price_area, update_value=update_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, ) # Append more order information if order_result.status == OrderStatus.Filled and order_counts[col] >= 1: order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col] order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col] order_records["type"][order_counts[col] - 1, col] = main_info["type"][col] order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col] # Update execution state cash_now = new_exec_state.cash position_now = new_exec_state.position debt_now = new_exec_state.debt locked_cash_now = new_exec_state.locked_cash free_cash_now = new_exec_state.free_cash val_price_now = new_exec_state.val_price value_now = new_exec_state.value if use_stops: # Update stop price if position_now == 0: # Not in position anymore -> clear stops (irrespective of order success) last_sl_info["init_idx"][col] = -1 last_sl_info["init_price"][col] = np.nan 
last_sl_info["init_position"][col] = np.nan last_sl_info["stop"][col] = np.nan last_sl_info["exit_price"][col] = -1 last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = -1 last_sl_info["order_type"][col] = -1 last_sl_info["limit_delta"][col] = np.nan last_sl_info["delta_format"][col] = -1 last_sl_info["ladder"][col] = -1 last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 last_tsl_info["init_idx"][col] = -1 last_tsl_info["init_price"][col] = np.nan last_tsl_info["init_position"][col] = np.nan last_tsl_info["peak_idx"][col] = -1 last_tsl_info["peak_price"][col] = np.nan last_tsl_info["stop"][col] = np.nan last_tsl_info["th"][col] = np.nan last_tsl_info["exit_price"][col] = -1 last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = -1 last_tsl_info["order_type"][col] = -1 last_tsl_info["limit_delta"][col] = np.nan last_tsl_info["delta_format"][col] = -1 last_tsl_info["ladder"][col] = -1 last_tsl_info["step"][col] = -1 last_tsl_info["step_idx"][col] = -1 last_tp_info["init_idx"][col] = -1 last_tp_info["init_price"][col] = np.nan last_tp_info["init_position"][col] = np.nan last_tp_info["stop"][col] = np.nan last_tp_info["exit_price"][col] = -1 last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = -1 last_tp_info["order_type"][col] = -1 last_tp_info["limit_delta"][col] = np.nan last_tp_info["delta_format"][col] = -1 last_tp_info["ladder"][col] = -1 last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 last_td_info["init_idx"][col] = -1 last_td_info["init_position"][col] = np.nan last_td_info["stop"][col] = -1 last_td_info["exit_price"][col] = -1 last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = -1 last_td_info["order_type"][col] = -1 last_td_info["limit_delta"][col] = np.nan 
last_td_info["delta_format"][col] = -1 last_td_info["time_delta_format"][col] = -1 last_td_info["ladder"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 last_dt_info["init_idx"][col] = -1 last_dt_info["init_position"][col] = np.nan last_dt_info["stop"][col] = -1 last_dt_info["exit_price"][col] = -1 last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = -1 last_dt_info["order_type"][col] = -1 last_dt_info["limit_delta"][col] = np.nan last_dt_info["delta_format"][col] = -1 last_dt_info["time_delta_format"][col] = -1 last_dt_info["ladder"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 else: if main_info["stop_type"][col] == StopType.SL: if last_sl_info["ladder"][col]: step = last_sl_info["step"][col] + 1 last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 if stop_ladder and last_sl_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_sl_steps: last_sl_info["stop"][col] = flex_select_nb(sl_stop_, step, col) last_sl_info["step"][col] = step last_sl_info["step_idx"][col] = i else: last_sl_info["stop"][col] = np.nan last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 else: last_sl_info["stop"][col] = np.nan last_sl_info["step"][col] = step last_sl_info["step_idx"][col] = i elif ( main_info["stop_type"][col] == StopType.TSL or main_info["stop_type"][col] == StopType.TTP ): if last_tsl_info["ladder"][col]: step = last_tsl_info["step"][col] + 1 last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 if stop_ladder and last_tsl_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_tsl_steps: last_tsl_info["stop"][col] = flex_select_nb(tsl_stop_, step, col) last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i else: last_tsl_info["stop"][col] = np.nan last_tsl_info["step"][col] = -1 
last_tsl_info["step_idx"][col] = -1 else: last_tsl_info["stop"][col] = np.nan last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i elif main_info["stop_type"][col] == StopType.TP: if last_tp_info["ladder"][col]: step = last_tp_info["step"][col] + 1 last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 if stop_ladder and last_tp_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_tp_steps: last_tp_info["stop"][col] = flex_select_nb(tp_stop_, step, col) last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i else: last_tp_info["stop"][col] = np.nan last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 else: last_tp_info["stop"][col] = np.nan last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i elif main_info["stop_type"][col] == StopType.TD: if last_td_info["ladder"][col]: step = last_td_info["step"][col] + 1 last_td_info["step"][col] = step last_td_info["step_idx"][col] = i last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 if stop_ladder and last_td_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_td_steps: last_td_info["stop"][col] = flex_select_nb(td_stop_, step, col) last_td_info["step"][col] = step last_td_info["step_idx"][col] = i else: last_td_info["stop"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 else: last_td_info["stop"][col] = -1 last_td_info["step"][col] = step last_td_info["step_idx"][col] = i elif main_info["stop_type"][col] == StopType.DT: if last_dt_info["ladder"][col]: step = last_dt_info["step"][col] + 1 last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 if stop_ladder and last_dt_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_dt_steps: last_dt_info["stop"][col] = flex_select_nb(dt_stop_, step, col) 
last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i else: last_dt_info["stop"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 else: last_dt_info["stop"][col] = -1 last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i if order_result.status == OrderStatus.Filled and position_now != 0: # Order filled and in position -> possibly set stops _price = main_info["price"][col] _stop_entry_price = flex_select_nb(stop_entry_price_, i, col) if _stop_entry_price < 0: if _stop_entry_price == StopEntryPrice.ValPrice: new_init_price = val_price_now can_use_ohlc = False elif _stop_entry_price == StopEntryPrice.Price: new_init_price = order.price can_use_ohlc = np.isinf(_price) and _price < 0 if np.isinf(new_init_price): if new_init_price > 0: new_init_price = flex_select_nb(close_, i, col) else: new_init_price = flex_select_nb(open_, i, col) elif _stop_entry_price == StopEntryPrice.FillPrice: new_init_price = order_result.price can_use_ohlc = np.isinf(_price) and _price < 0 elif _stop_entry_price == StopEntryPrice.Open: new_init_price = flex_select_nb(open_, i, col) can_use_ohlc = True elif _stop_entry_price == StopEntryPrice.Close: new_init_price = flex_select_nb(close_, i, col) can_use_ohlc = False else: raise ValueError("Invalid StopEntryPrice option") else: new_init_price = _stop_entry_price can_use_ohlc = False if stop_ladder: _sl_stop = flex_select_nb(sl_stop_, 0, col) _tsl_stop = flex_select_nb(tsl_stop_, 0, col) _tp_stop = flex_select_nb(tp_stop_, 0, col) _td_stop = flex_select_nb(td_stop_, 0, col) _dt_stop = flex_select_nb(dt_stop_, 0, col) else: _sl_stop = flex_select_nb(sl_stop_, i, col) _tsl_stop = flex_select_nb(tsl_stop_, i, col) _tp_stop = flex_select_nb(tp_stop_, i, col) _td_stop = flex_select_nb(td_stop_, i, col) _dt_stop = flex_select_nb(dt_stop_, i, col) _tsl_th = flex_select_nb(tsl_th_, i, col) _stop_exit_price = flex_select_nb(stop_exit_price_, i, col) _stop_exit_type = flex_select_nb(stop_exit_type_, i, 
col) _stop_order_type = flex_select_nb(stop_order_type_, i, col) _stop_limit_delta = flex_select_nb(stop_limit_delta_, i, col) _delta_format = flex_select_nb(delta_format_, i, col) _time_delta_format = flex_select_nb(time_delta_format_, i, col) tsl_updated = False if exec_state.position == 0 or np.sign(position_now) != np.sign(exec_state.position): # Position opened/reversed -> set stops last_sl_info["init_idx"][col] = i last_sl_info["init_price"][col] = new_init_price last_sl_info["init_position"][col] = position_now last_sl_info["stop"][col] = _sl_stop last_sl_info["exit_price"][col] = _stop_exit_price last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = _stop_exit_type last_sl_info["order_type"][col] = _stop_order_type last_sl_info["limit_delta"][col] = _stop_limit_delta last_sl_info["delta_format"][col] = _delta_format last_sl_info["ladder"][col] = stop_ladder last_sl_info["step"][col] = 0 last_sl_info["step_idx"][col] = i tsl_updated = True last_tsl_info["init_idx"][col] = i last_tsl_info["init_price"][col] = new_init_price last_tsl_info["init_position"][col] = position_now last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = new_init_price last_tsl_info["stop"][col] = _tsl_stop last_tsl_info["th"][col] = _tsl_th last_tsl_info["exit_price"][col] = _stop_exit_price last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = _stop_exit_type last_tsl_info["order_type"][col] = _stop_order_type last_tsl_info["limit_delta"][col] = _stop_limit_delta last_tsl_info["delta_format"][col] = _delta_format last_tsl_info["ladder"][col] = stop_ladder last_tsl_info["step"][col] = 0 last_tsl_info["step_idx"][col] = i last_tp_info["init_idx"][col] = i last_tp_info["init_price"][col] = new_init_price last_tp_info["init_position"][col] = position_now last_tp_info["stop"][col] = _tp_stop last_tp_info["exit_price"][col] = _stop_exit_price 
last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = _stop_exit_type last_tp_info["order_type"][col] = _stop_order_type last_tp_info["limit_delta"][col] = _stop_limit_delta last_tp_info["delta_format"][col] = _delta_format last_tp_info["ladder"][col] = stop_ladder last_tp_info["step"][col] = 0 last_tp_info["step_idx"][col] = i last_td_info["init_idx"][col] = i last_td_info["init_position"][col] = position_now last_td_info["stop"][col] = _td_stop last_td_info["exit_price"][col] = _stop_exit_price last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = _stop_exit_type last_td_info["order_type"][col] = _stop_order_type last_td_info["limit_delta"][col] = _stop_limit_delta last_td_info["delta_format"][col] = _delta_format last_td_info["time_delta_format"][col] = _time_delta_format last_td_info["ladder"][col] = stop_ladder last_td_info["step"][col] = 0 last_td_info["step_idx"][col] = i last_dt_info["init_idx"][col] = i last_dt_info["init_position"][col] = position_now last_dt_info["stop"][col] = _dt_stop last_dt_info["exit_price"][col] = _stop_exit_price last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = _stop_exit_type last_dt_info["order_type"][col] = _stop_order_type last_dt_info["limit_delta"][col] = _stop_limit_delta last_dt_info["delta_format"][col] = _delta_format last_dt_info["time_delta_format"][col] = _time_delta_format last_dt_info["ladder"][col] = stop_ladder last_dt_info["step"][col] = 0 last_dt_info["step_idx"][col] = i elif abs(position_now) > abs(exec_state.position): # Position increased -> keep/override stops _upon_stop_update = flex_select_nb(upon_stop_update_, i, col) if should_update_stop_nb(new_stop=_sl_stop, upon_stop_update=_upon_stop_update): last_sl_info["init_idx"][col] = i last_sl_info["init_price"][col] = new_init_price last_sl_info["init_position"][col] = 
position_now last_sl_info["stop"][col] = _sl_stop last_sl_info["exit_price"][col] = _stop_exit_price last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = _stop_exit_type last_sl_info["order_type"][col] = _stop_order_type last_sl_info["limit_delta"][col] = _stop_limit_delta last_sl_info["delta_format"][col] = _delta_format last_sl_info["ladder"][col] = stop_ladder last_sl_info["step"][col] = 0 last_sl_info["step_idx"][col] = i if should_update_stop_nb(new_stop=_tsl_stop, upon_stop_update=_upon_stop_update): tsl_updated = True last_tsl_info["init_idx"][col] = i last_tsl_info["init_price"][col] = new_init_price last_tsl_info["init_position"][col] = position_now last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = new_init_price last_tsl_info["stop"][col] = _tsl_stop last_tsl_info["th"][col] = _tsl_th last_tsl_info["exit_price"][col] = _stop_exit_price last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = _stop_exit_type last_tsl_info["order_type"][col] = _stop_order_type last_tsl_info["limit_delta"][col] = _stop_limit_delta last_tsl_info["delta_format"][col] = _delta_format last_tsl_info["ladder"][col] = stop_ladder last_tsl_info["step"][col] = 0 last_tsl_info["step_idx"][col] = i if should_update_stop_nb(new_stop=_tp_stop, upon_stop_update=_upon_stop_update): last_tp_info["init_idx"][col] = i last_tp_info["init_price"][col] = new_init_price last_tp_info["init_position"][col] = position_now last_tp_info["stop"][col] = _tp_stop last_tp_info["exit_price"][col] = _stop_exit_price last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = _stop_exit_type last_tp_info["order_type"][col] = _stop_order_type last_tp_info["limit_delta"][col] = _stop_limit_delta last_tp_info["delta_format"][col] = _delta_format last_tp_info["ladder"][col] = stop_ladder last_tp_info["step"][col] = 0 
last_tp_info["step_idx"][col] = i if should_update_time_stop_nb( new_stop=_td_stop, upon_stop_update=_upon_stop_update ): last_td_info["init_idx"][col] = i last_td_info["init_position"][col] = position_now last_td_info["stop"][col] = _td_stop last_td_info["exit_price"][col] = _stop_exit_price last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = _stop_exit_type last_td_info["order_type"][col] = _stop_order_type last_td_info["limit_delta"][col] = _stop_limit_delta last_td_info["delta_format"][col] = _delta_format last_td_info["time_delta_format"][col] = _time_delta_format last_td_info["ladder"][col] = stop_ladder last_td_info["step"][col] = 0 last_td_info["step_idx"][col] = i if should_update_time_stop_nb( new_stop=_dt_stop, upon_stop_update=_upon_stop_update ): last_dt_info["init_idx"][col] = i last_dt_info["init_position"][col] = position_now last_dt_info["stop"][col] = _dt_stop last_dt_info["exit_price"][col] = _stop_exit_price last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = _stop_exit_type last_dt_info["order_type"][col] = _stop_order_type last_dt_info["limit_delta"][col] = _stop_limit_delta last_dt_info["delta_format"][col] = _delta_format last_dt_info["time_delta_format"][col] = _time_delta_format last_dt_info["ladder"][col] = stop_ladder last_dt_info["step"][col] = 0 last_dt_info["step_idx"][col] = i if tsl_updated: # Update highest/lowest price if can_use_ohlc: _open = flex_select_nb(open_, i, col) _high = flex_select_nb(high_, i, col) _low = flex_select_nb(low_, i, col) _close = flex_select_nb(close_, i, col) _high, _low = resolve_hl_nb( open=_open, high=_high, low=_low, close=_close, ) else: _open = np.nan _high = _low = _close = flex_select_nb(close_, i, col) if tsl_updated: if position_now > 0: if _high > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( 
last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _high elif position_now < 0: if _low < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _low # Now becomes last last_position[col] = position_now last_debt[col] = debt_now last_locked_cash[col] = locked_cash_now if not np.isnan(val_price_now) or not ffill_val_price: last_val_price[col] = val_price_now last_cash[group] = cash_now last_free_cash[group] = free_cash_now last_value[group] = value_now last_return[group] = return_now for col in range(from_col, to_col): # Update valuation price using current close _close = flex_select_nb(close_, i, col) if not np.isnan(_close) or not ffill_val_price: last_val_price[col] = _close _cash_earnings = flex_select_nb(cash_earnings_, i, col) _cash_dividends = flex_select_nb(cash_dividends_, i, col) _cash_earnings += _cash_dividends * last_position[col] last_cash[group] += _cash_earnings last_free_cash[group] += _cash_earnings if track_cash_earnings: cash_earnings_out[i, col] += _cash_earnings if save_state: position[i, col] = last_position[col] debt[i, col] = last_debt[col] locked_cash[i, col] = last_locked_cash[col] cash[i, group] = last_cash[group] free_cash[i, group] = last_free_cash[group] # Update value and return group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - _cash_deposits, ) prev_close_value[group] = last_value[group] if save_value: in_outputs.value[i, group] = last_value[group] if save_returns: in_outputs.returns[i, group] = 
@register_jitted(cache=True)
def init_FSInOutputs_nb(
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    cash_sharing: bool = False,
    save_state: bool = True,
    save_value: bool = True,
    save_returns: bool = True,
):
    """Allocate NaN-filled arrays and pack them into `vectorbtpro.portfolio.enums.FSInOutputs`.

    Shapes of the allocated arrays:

    * `position`, `debt`, `locked_cash`: `target_shape`.
    * `cash`, `free_cash`, `value`, `returns`: `(target_shape[0], len(group_lens))`
      if `cash_sharing` is enabled, otherwise `target_shape`.

    An array whose switch is turned off (`save_state` for the state arrays, `save_value`
    for `value`, `save_returns` for `returns`) is substituted with an empty `(0, 0)` array.
    """
    n_rows = target_shape[0]
    n_groups = len(group_lens)

    # State arrays: position, debt, locked cash, cash, and free cash
    if not save_state:
        position = np.full((0, 0), np.nan)
        debt = np.full((0, 0), np.nan)
        locked_cash = np.full((0, 0), np.nan)
        cash = np.full((0, 0), np.nan)
        free_cash = np.full((0, 0), np.nan)
    else:
        position = np.full(target_shape, np.nan)
        debt = np.full(target_shape, np.nan)
        locked_cash = np.full(target_shape, np.nan)
        if not cash_sharing:
            cash = np.full(target_shape, np.nan)
            free_cash = np.full(target_shape, np.nan)
        else:
            cash = np.full((n_rows, n_groups), np.nan)
            free_cash = np.full((n_rows, n_groups), np.nan)

    # Value array
    if save_value and cash_sharing:
        value = np.full((n_rows, n_groups), np.nan)
    elif save_value:
        value = np.full(target_shape, np.nan)
    else:
        value = np.full((0, 0), np.nan)

    # Returns array
    if save_returns and cash_sharing:
        returns = np.full((n_rows, n_groups), np.nan)
    elif save_returns:
        returns = np.full(target_shape, np.nan)
    else:
        returns = np.full((0, 0), np.nan)

    return FSInOutputs(
        position=position,
        debt=debt,
        locked_cash=locked_cash,
        cash=cash,
        free_cash=free_cash,
        value=value,
        returns=returns,
    )


@register_jitted
def no_signal_func_nb(c: SignalContext, *args) -> tp.Tuple[bool, bool, bool, bool]:
    """Placeholder `signal_func_nb` that never signals.

    Ignores the context and any extra arguments and always returns four `False` flags."""
    return (False, False, False, False)
= tp.Callable[[SignalContext, tp.VarArg()], tp.Tuple[bool, bool, bool, bool]] PostSignalFuncT = tp.Callable[[PostSignalContext, tp.VarArg()], None] PostSegmentFuncT = tp.Callable[[SignalSegmentContext, tp.VarArg()], None] # % # % # % # @register_jitted # def signal_func_nb( # c: SignalContext, # ) -> tp.Tuple[bool, bool, bool, bool]: # """Custom signal function.""" # return False, False, False, False # # # % # % # % # % # % # % # @register_jitted # def post_signal_func_nb( # c: PostSignalContext, # ) -> None: # """Custom post-signal function.""" # return None # # # % # % # % # % # % # % # @register_jitted # def post_segment_func_nb( # c: SignalSegmentContext, # ) -> None: # """Custom post-segment function.""" # return None # # # % # % # % # %
# % # import vectorbtpro as vbt # from vectorbtpro.portfolio.nb.from_signals import * # %? import_lines # # # % # %? blocks[signal_func_nb_block] # % blocks["signal_func_nb"] # %? blocks[post_signal_func_nb_block] # % blocks["post_signal_func_nb"] # %? blocks[post_segment_func_nb_block] # % blocks["post_segment_func_nb"] @register_chunkable( size=ch.ArraySizer(arg_query="group_lens", axis=0), arg_take_spec=dict( target_shape=base_ch.shape_gl_slicer, group_lens=ch.ArraySlicer(axis=0), cash_sharing=None, index=None, freq=None, open=base_ch.flex_array_gl_slicer, high=base_ch.flex_array_gl_slicer, low=base_ch.flex_array_gl_slicer, close=base_ch.flex_array_gl_slicer, init_cash=RepFunc(portfolio_ch.get_init_cash_slicer), init_position=base_ch.flex_1d_array_gl_slicer, init_price=base_ch.flex_1d_array_gl_slicer, cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer), cash_earnings=base_ch.flex_array_gl_slicer, cash_dividends=base_ch.flex_array_gl_slicer, signal_func_nb=None, # % None signal_args=ch.ArgsTaker(), post_signal_func_nb=None, # % None post_signal_args=ch.ArgsTaker(), post_segment_func_nb=None, # % None post_segment_args=ch.ArgsTaker(), size=base_ch.flex_array_gl_slicer, price=base_ch.flex_array_gl_slicer, size_type=base_ch.flex_array_gl_slicer, fees=base_ch.flex_array_gl_slicer, fixed_fees=base_ch.flex_array_gl_slicer, slippage=base_ch.flex_array_gl_slicer, min_size=base_ch.flex_array_gl_slicer, max_size=base_ch.flex_array_gl_slicer, size_granularity=base_ch.flex_array_gl_slicer, leverage=base_ch.flex_array_gl_slicer, leverage_mode=base_ch.flex_array_gl_slicer, reject_prob=base_ch.flex_array_gl_slicer, price_area_vio_mode=base_ch.flex_array_gl_slicer, allow_partial=base_ch.flex_array_gl_slicer, raise_reject=base_ch.flex_array_gl_slicer, log=base_ch.flex_array_gl_slicer, val_price=base_ch.flex_array_gl_slicer, accumulate=base_ch.flex_array_gl_slicer, upon_long_conflict=base_ch.flex_array_gl_slicer, upon_short_conflict=base_ch.flex_array_gl_slicer, 
upon_dir_conflict=base_ch.flex_array_gl_slicer, upon_opposite_entry=base_ch.flex_array_gl_slicer, order_type=base_ch.flex_array_gl_slicer, limit_delta=base_ch.flex_array_gl_slicer, limit_tif=base_ch.flex_array_gl_slicer, limit_expiry=base_ch.flex_array_gl_slicer, limit_reverse=base_ch.flex_array_gl_slicer, limit_order_price=base_ch.flex_array_gl_slicer, upon_adj_limit_conflict=base_ch.flex_array_gl_slicer, upon_opp_limit_conflict=base_ch.flex_array_gl_slicer, use_stops=None, stop_ladder=None, sl_stop=base_ch.flex_array_gl_slicer, tsl_stop=base_ch.flex_array_gl_slicer, tsl_th=base_ch.flex_array_gl_slicer, tp_stop=base_ch.flex_array_gl_slicer, td_stop=base_ch.flex_array_gl_slicer, dt_stop=base_ch.flex_array_gl_slicer, stop_entry_price=base_ch.flex_array_gl_slicer, stop_exit_price=base_ch.flex_array_gl_slicer, stop_exit_type=base_ch.flex_array_gl_slicer, stop_order_type=base_ch.flex_array_gl_slicer, stop_limit_delta=base_ch.flex_array_gl_slicer, upon_stop_update=base_ch.flex_array_gl_slicer, upon_adj_stop_conflict=base_ch.flex_array_gl_slicer, upon_opp_stop_conflict=base_ch.flex_array_gl_slicer, delta_format=base_ch.flex_array_gl_slicer, time_delta_format=base_ch.flex_array_gl_slicer, from_ago=base_ch.flex_array_gl_slicer, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), call_seq=base_ch.array_gl_slicer, auto_call_seq=None, ffill_val_price=None, update_value=None, fill_pos_info=None, skip_empty=None, max_order_records=None, max_log_records=None, in_outputs=ch.ArgsTaker(), ), **portfolio_ch.merge_sim_outs_config, setup_id=None, # %? line.replace("None", task_id) ) @register_jitted( tags={"can_parallel"}, cache=False, # % line.replace("False", "True") task_id_or_func=None, # %? line.replace("None", task_id) ) def from_signal_func_nb( # %? 
line.replace("from_signal_func_nb", new_func_name) target_shape: tp.Shape, group_lens: tp.GroupLens, cash_sharing: bool, index: tp.Optional[tp.Array1d] = None, freq: tp.Optional[int] = None, open: tp.FlexArray2dLike = np.nan, high: tp.FlexArray2dLike = np.nan, low: tp.FlexArray2dLike = np.nan, close: tp.FlexArray2dLike = np.nan, init_cash: tp.FlexArray1dLike = 100.0, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, cash_deposits: tp.FlexArray2dLike = 0.0, cash_earnings: tp.FlexArray2dLike = 0.0, cash_dividends: tp.FlexArray2dLike = 0.0, signal_func_nb: SignalFuncT = no_signal_func_nb, # % None signal_args: tp.ArgsLike = (), post_signal_func_nb: PostSignalFuncT = no_post_func_nb, # % None post_signal_args: tp.ArgsLike = (), post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None post_segment_args: tp.ArgsLike = (), size: tp.FlexArray2dLike = np.inf, price: tp.FlexArray2dLike = np.inf, size_type: tp.FlexArray2dLike = SizeType.Amount, fees: tp.FlexArray2dLike = 0.0, fixed_fees: tp.FlexArray2dLike = 0.0, slippage: tp.FlexArray2dLike = 0.0, min_size: tp.FlexArray2dLike = np.nan, max_size: tp.FlexArray2dLike = np.nan, size_granularity: tp.FlexArray2dLike = np.nan, leverage: tp.FlexArray2dLike = 1.0, leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy, reject_prob: tp.FlexArray2dLike = 0.0, price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore, allow_partial: tp.FlexArray2dLike = True, raise_reject: tp.FlexArray2dLike = False, log: tp.FlexArray2dLike = False, val_price: tp.FlexArray2dLike = np.inf, accumulate: tp.FlexArray2dLike = AccumulationMode.Disabled, upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore, upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore, upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce, order_type: tp.FlexArray2dLike = OrderType.Market, limit_delta: tp.FlexArray2dLike = np.nan, 
limit_tif: tp.FlexArray2dLike = -1, limit_expiry: tp.FlexArray2dLike = -1, limit_reverse: tp.FlexArray2dLike = False, limit_order_price: tp.FlexArray2dLike = LimitOrderPrice.Limit, upon_adj_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepIgnore, upon_opp_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.CancelExecute, use_stops: bool = True, stop_ladder: int = StopLadderMode.Disabled, sl_stop: tp.FlexArray2dLike = np.nan, tsl_stop: tp.FlexArray2dLike = np.nan, tsl_th: tp.FlexArray2dLike = np.nan, tp_stop: tp.FlexArray2dLike = np.nan, td_stop: tp.FlexArray2dLike = -1, dt_stop: tp.FlexArray2dLike = -1, stop_entry_price: tp.FlexArray2dLike = StopEntryPrice.Close, stop_exit_price: tp.FlexArray2dLike = StopExitPrice.Stop, stop_exit_type: tp.FlexArray2dLike = StopExitType.Close, stop_order_type: tp.FlexArray2dLike = OrderType.Market, stop_limit_delta: tp.FlexArray2dLike = np.nan, upon_stop_update: tp.FlexArray2dLike = StopUpdateMode.Keep, upon_adj_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute, upon_opp_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute, delta_format: tp.FlexArray2dLike = DeltaFormat.Percent, time_delta_format: tp.FlexArray2dLike = TimeDeltaFormat.Index, from_ago: tp.FlexArray2dLike = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, call_seq: tp.Optional[tp.Array2d] = None, auto_call_seq: bool = False, ffill_val_price: bool = True, update_value: bool = False, fill_pos_info: bool = True, skip_empty: bool = True, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = 0, in_outputs: tp.Optional[tp.NamedTuple] = None, ) -> SimulationOutput: """Simulate given a signal function. Iterates in the column-major order. Utilizes flexible broadcasting. `signal_func_nb` is a user-defined signal generation function that is called at each row and column (= element). 
It must accept the context of the type `vectorbtpro.portfolio.enums.SignalContext` and return 4 signals: long entry, long exit, short entry, and short exit. `post_signal_func_nb` is a user-defined post-signal function that is called after an order has been processed. It must accept the context of the type `vectorbtpro.portfolio.enums.PostSignalContext` and return nothing. `post_segment_func_nb` is a user-defined post-segment function that is called after each row and group (= segment). It must accept the context of the type `vectorbtpro.portfolio.enums.SignalSegmentContext` and return nothing. """ check_group_lens_nb(group_lens, target_shape[1]) open_ = to_2d_array_nb(np.asarray(open)) high_ = to_2d_array_nb(np.asarray(high)) low_ = to_2d_array_nb(np.asarray(low)) close_ = to_2d_array_nb(np.asarray(close)) init_cash_ = to_1d_array_nb(np.asarray(init_cash)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits)) cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings)) cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends)) size_ = to_2d_array_nb(np.asarray(size)) price_ = to_2d_array_nb(np.asarray(price)) size_type_ = to_2d_array_nb(np.asarray(size_type)) fees_ = to_2d_array_nb(np.asarray(fees)) fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees)) slippage_ = to_2d_array_nb(np.asarray(slippage)) min_size_ = to_2d_array_nb(np.asarray(min_size)) max_size_ = to_2d_array_nb(np.asarray(max_size)) size_granularity_ = to_2d_array_nb(np.asarray(size_granularity)) leverage_ = to_2d_array_nb(np.asarray(leverage)) leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode)) reject_prob_ = to_2d_array_nb(np.asarray(reject_prob)) price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode)) allow_partial_ = to_2d_array_nb(np.asarray(allow_partial)) raise_reject_ = to_2d_array_nb(np.asarray(raise_reject)) log_ = to_2d_array_nb(np.asarray(log)) 
val_price_ = to_2d_array_nb(np.asarray(val_price)) accumulate_ = to_2d_array_nb(np.asarray(accumulate)) upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict)) upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict)) upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict)) upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry)) order_type_ = to_2d_array_nb(np.asarray(order_type)) limit_delta_ = to_2d_array_nb(np.asarray(limit_delta)) limit_tif_ = to_2d_array_nb(np.asarray(limit_tif)) limit_expiry_ = to_2d_array_nb(np.asarray(limit_expiry)) limit_reverse_ = to_2d_array_nb(np.asarray(limit_reverse)) limit_order_price_ = to_2d_array_nb(np.asarray(limit_order_price)) upon_adj_limit_conflict_ = to_2d_array_nb(np.asarray(upon_adj_limit_conflict)) upon_opp_limit_conflict_ = to_2d_array_nb(np.asarray(upon_opp_limit_conflict)) sl_stop_ = to_2d_array_nb(np.asarray(sl_stop)) tsl_stop_ = to_2d_array_nb(np.asarray(tsl_stop)) tsl_th_ = to_2d_array_nb(np.asarray(tsl_th)) tp_stop_ = to_2d_array_nb(np.asarray(tp_stop)) td_stop_ = to_2d_array_nb(np.asarray(td_stop)) dt_stop_ = to_2d_array_nb(np.asarray(dt_stop)) stop_entry_price_ = to_2d_array_nb(np.asarray(stop_entry_price)) stop_exit_price_ = to_2d_array_nb(np.asarray(stop_exit_price)) stop_exit_type_ = to_2d_array_nb(np.asarray(stop_exit_type)) stop_order_type_ = to_2d_array_nb(np.asarray(stop_order_type)) stop_limit_delta_ = to_2d_array_nb(np.asarray(stop_limit_delta)) upon_stop_update_ = to_2d_array_nb(np.asarray(upon_stop_update)) upon_adj_stop_conflict_ = to_2d_array_nb(np.asarray(upon_adj_stop_conflict)) upon_opp_stop_conflict_ = to_2d_array_nb(np.asarray(upon_opp_stop_conflict)) delta_format_ = to_2d_array_nb(np.asarray(delta_format)) time_delta_format_ = to_2d_array_nb(np.asarray(time_delta_format)) from_ago_ = to_2d_array_nb(np.asarray(from_ago)) n_sl_steps = sl_stop_.shape[0] n_tsl_steps = tsl_stop_.shape[0] n_tp_steps = tp_stop_.shape[0] n_td_steps = td_stop_.shape[0] 
n_dt_steps = dt_stop_.shape[0] order_records, log_records = prepare_fs_records_nb( target_shape=target_shape, max_order_records=max_order_records, max_log_records=max_log_records, ) order_counts = np.full(target_shape[1], 0, dtype=int_) log_counts = np.full(target_shape[1], 0, dtype=int_) last_cash = prepare_last_cash_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, ) last_position = prepare_last_position_nb( target_shape=target_shape, init_position=init_position_, ) last_value = prepare_last_value_nb( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, ) last_pos_info = prepare_last_pos_info_nb( target_shape, init_position=init_position_, init_price=init_price_, fill_pos_info=fill_pos_info, ) last_cash_deposits = np.full_like(last_cash, 0.0) last_val_price = np.full_like(last_position, np.nan) last_debt = np.full(target_shape[1], 0.0, dtype=float_) last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_) last_free_cash = last_cash.copy() prev_close_value = last_value.copy() last_return = np.full_like(last_cash, np.nan) track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1 if track_cash_deposits: cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_) else: cash_deposits_out = np.full((1, 1), 0.0, dtype=float_) track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1 track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1 track_cash_earnings = track_cash_earnings or track_cash_dividends if track_cash_earnings: cash_earnings_out = np.full(target_shape, 0.0, dtype=float_) else: cash_earnings_out = np.full((1, 1), 0.0, dtype=float_) last_limit_info = np.empty(target_shape[1], dtype=limit_info_dt) last_limit_info["signal_idx"][:] = 
-1 last_limit_info["creation_idx"][:] = -1 last_limit_info["init_idx"][:] = -1 last_limit_info["init_price"][:] = np.nan last_limit_info["init_size"][:] = np.nan last_limit_info["init_size_type"][:] = -1 last_limit_info["init_direction"][:] = -1 last_limit_info["init_stop_type"][:] = -1 last_limit_info["delta"][:] = np.nan last_limit_info["delta_format"][:] = -1 last_limit_info["tif"][:] = -1 last_limit_info["expiry"][:] = -1 last_limit_info["time_delta_format"][:] = -1 last_limit_info["reverse"][:] = False last_limit_info["order_price"][:] = np.nan if use_stops: last_sl_info = np.empty(target_shape[1], dtype=sl_info_dt) last_sl_info["init_idx"][:] = -1 last_sl_info["init_price"][:] = np.nan last_sl_info["init_position"][:] = np.nan last_sl_info["stop"][:] = np.nan last_sl_info["exit_price"][:] = -1 last_sl_info["exit_size"][:] = np.nan last_sl_info["exit_size_type"][:] = -1 last_sl_info["exit_type"][:] = -1 last_sl_info["order_type"][:] = -1 last_sl_info["limit_delta"][:] = np.nan last_sl_info["delta_format"][:] = -1 last_sl_info["ladder"][:] = -1 last_sl_info["step"][:] = -1 last_sl_info["step_idx"][:] = -1 last_tsl_info = np.empty(target_shape[1], dtype=tsl_info_dt) last_tsl_info["init_idx"][:] = -1 last_tsl_info["init_price"][:] = np.nan last_tsl_info["init_position"][:] = np.nan last_tsl_info["peak_idx"][:] = -1 last_tsl_info["peak_price"][:] = np.nan last_tsl_info["stop"][:] = np.nan last_tsl_info["th"][:] = np.nan last_tsl_info["exit_price"][:] = -1 last_tsl_info["exit_size"][:] = np.nan last_tsl_info["exit_size_type"][:] = -1 last_tsl_info["exit_type"][:] = -1 last_tsl_info["order_type"][:] = -1 last_tsl_info["limit_delta"][:] = np.nan last_tsl_info["delta_format"][:] = -1 last_tsl_info["ladder"][:] = -1 last_tsl_info["step"][:] = -1 last_tsl_info["step_idx"][:] = -1 last_tp_info = np.empty(target_shape[1], dtype=tp_info_dt) last_tp_info["init_idx"][:] = -1 last_tp_info["init_price"][:] = np.nan last_tp_info["init_position"][:] = np.nan 
last_tp_info["stop"][:] = np.nan last_tp_info["exit_price"][:] = -1 last_tp_info["exit_size"][:] = np.nan last_tp_info["exit_size_type"][:] = -1 last_tp_info["exit_type"][:] = -1 last_tp_info["order_type"][:] = -1 last_tp_info["limit_delta"][:] = np.nan last_tp_info["delta_format"][:] = -1 last_tp_info["ladder"][:] = -1 last_tp_info["step"][:] = -1 last_tp_info["step_idx"][:] = -1 last_td_info = np.empty(target_shape[1], dtype=time_info_dt) last_td_info["init_idx"][:] = -1 last_td_info["init_position"][:] = np.nan last_td_info["stop"][:] = -1 last_td_info["exit_price"][:] = -1 last_td_info["exit_size"][:] = np.nan last_td_info["exit_size_type"][:] = -1 last_td_info["exit_type"][:] = -1 last_td_info["order_type"][:] = -1 last_td_info["limit_delta"][:] = np.nan last_td_info["delta_format"][:] = -1 last_td_info["time_delta_format"][:] = -1 last_td_info["ladder"][:] = -1 last_td_info["step"][:] = -1 last_td_info["step_idx"][:] = -1 last_dt_info = np.empty(target_shape[1], dtype=time_info_dt) last_dt_info["init_idx"][:] = -1 last_dt_info["init_position"][:] = np.nan last_dt_info["stop"][:] = -1 last_dt_info["exit_price"][:] = -1 last_dt_info["exit_size"][:] = np.nan last_dt_info["exit_size_type"][:] = -1 last_dt_info["exit_type"][:] = -1 last_dt_info["order_type"][:] = -1 last_dt_info["limit_delta"][:] = np.nan last_dt_info["delta_format"][:] = -1 last_dt_info["time_delta_format"][:] = -1 last_dt_info["ladder"][:] = -1 last_dt_info["step"][:] = -1 last_dt_info["step_idx"][:] = -1 else: last_sl_info = np.empty(0, dtype=sl_info_dt) last_tsl_info = np.empty(0, dtype=tsl_info_dt) last_tp_info = np.empty(0, dtype=tp_info_dt) last_td_info = np.empty(0, dtype=time_info_dt) last_dt_info = np.empty(0, dtype=time_info_dt) last_signal = np.empty(target_shape[1], dtype=int_) main_info = np.empty(target_shape[1], dtype=main_info_dt) temp_call_seq = np.empty(target_shape[1], dtype=int_) temp_sort_by = np.empty(target_shape[1], dtype=float_) group_end_idxs = np.cumsum(group_lens) 
group_start_idxs = group_end_idxs - group_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(target_shape[0], len(group_lens)), sim_start=sim_start, sim_end=sim_end, ) for group in prange(len(group_lens)): from_col = group_start_idxs[group] to_col = group_end_idxs[group] group_len = to_col - from_col _sim_start = sim_start_[group] _sim_end = sim_end_[group] for i in range(_sim_start, _sim_end): # Add cash if cash_sharing: _cash_deposits = flex_select_nb(cash_deposits_, i, group) if _cash_deposits < 0: _cash_deposits = max(_cash_deposits, -last_cash[group]) last_cash[group] += _cash_deposits last_free_cash[group] += _cash_deposits last_cash_deposits[group] = _cash_deposits if track_cash_deposits: cash_deposits_out[i, group] += _cash_deposits else: for col in range(from_col, to_col): _cash_deposits = flex_select_nb(cash_deposits_, i, col) if _cash_deposits < 0: _cash_deposits = max(_cash_deposits, -last_cash[col]) last_cash[col] += _cash_deposits last_free_cash[col] += _cash_deposits last_cash_deposits[col] = _cash_deposits if track_cash_deposits: cash_deposits_out[i, col] += _cash_deposits # Update valuation price using current open for c in range(group_len): col = from_col + c _open = flex_select_nb(open_, i, col) if not np.isnan(_open) or not ffill_val_price: last_val_price[col] = _open # Update value and return if cash_sharing: group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - last_cash_deposits[group], ) else: for col in range(from_col, to_col): group_value = last_cash[col] if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[col] = group_value last_return[col] = get_return_nb( input_value=prev_close_value[col], output_value=last_value[col] - 
last_cash_deposits[col], ) # Update open position stats if fill_pos_info: for col in range(from_col, to_col): update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col]) # Get signals skip = skip_empty for c in range(group_len): col = from_col + c signal_ctx = SignalContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, track_cash_deposits=track_cash_deposits, cash_deposits_out=cash_deposits_out, track_cash_earnings=track_cash_earnings, cash_earnings_out=cash_earnings_out, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, last_limit_info=last_limit_info, last_sl_info=last_sl_info, last_tsl_info=last_tsl_info, last_tp_info=last_tp_info, last_td_info=last_td_info, last_dt_info=last_dt_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, col=col, ) is_long_entry, is_long_exit, is_short_entry, is_short_exit = signal_func_nb(signal_ctx, *signal_args) # Update limit and stop prices _i = i - abs(flex_select_nb(from_ago_, i, col)) if _i < 0: _price = np.nan else: _price = flex_select_nb(price_, _i, col) last_limit_info["init_price"][col] = resolve_dyn_limit_price_nb( val_price=last_val_price[col], price=_price, limit_price=last_limit_info["init_price"][col], ) last_sl_info["init_price"][col] = resolve_dyn_stop_entry_price_nb( val_price=last_val_price[col], price=_price, stop_entry_price=last_sl_info["init_price"][col], ) last_tsl_info["init_price"][col] = 
resolve_dyn_stop_entry_price_nb( val_price=last_val_price[col], price=_price, stop_entry_price=last_tsl_info["init_price"][col], ) last_tsl_info["peak_price"][col] = resolve_dyn_stop_entry_price_nb( val_price=last_val_price[col], price=_price, stop_entry_price=last_tsl_info["peak_price"][col], ) last_tp_info["init_price"][col] = resolve_dyn_stop_entry_price_nb( val_price=last_val_price[col], price=_price, stop_entry_price=last_tp_info["init_price"][col], ) limit_signal = is_limit_active_nb( init_idx=last_limit_info["init_idx"][col], init_price=last_limit_info["init_price"][col], ) if not use_stops: sl_stop_signal = False tsl_stop_signal = False tp_stop_signal = False td_stop_signal = False dt_stop_signal = False else: sl_stop_signal = is_stop_active_nb( init_idx=last_sl_info["init_idx"][col], stop=last_sl_info["stop"][col], ) tsl_stop_signal = is_stop_active_nb( init_idx=last_tsl_info["init_idx"][col], stop=last_tsl_info["stop"][col], ) tp_stop_signal = is_stop_active_nb( init_idx=last_tp_info["init_idx"][col], stop=last_tp_info["stop"][col], ) td_stop_signal = is_time_stop_active_nb( init_idx=last_td_info["init_idx"][col], stop=last_td_info["stop"][col], ) dt_stop_signal = is_time_stop_active_nb( init_idx=last_dt_info["init_idx"][col], stop=last_dt_info["stop"][col], ) # Pack signals into a single integer last_signal[col] = ( (is_long_entry << 10) | (is_long_exit << 9) | (is_short_entry << 8) | (is_short_exit << 7) | (limit_signal << 6) | (sl_stop_signal << 5) | (tsl_stop_signal << 4) | (tp_stop_signal << 3) | (td_stop_signal << 2) | (dt_stop_signal << 1) ) if last_signal[col] > 0: skip = False if not skip: # Update value and return if cash_sharing: group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - last_cash_deposits[group], ) 
else: for col in range(from_col, to_col): group_value = last_cash[col] if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[col] = group_value last_return[col] = get_return_nb( input_value=prev_close_value[col], output_value=last_value[col] - last_cash_deposits[col], ) # Get size and value of each order for c in range(group_len): col = from_col + c # Set defaults main_info["bar_zone"][col] = -1 main_info["signal_idx"][col] = -1 main_info["creation_idx"][col] = -1 main_info["idx"][col] = i main_info["val_price"][col] = np.nan main_info["price"][col] = np.nan main_info["size"][col] = np.nan main_info["size_type"][col] = -1 main_info["direction"][col] = -1 main_info["type"][col] = -1 main_info["stop_type"][col] = -1 temp_sort_by[col] = 0.0 # Unpack a single integer into signals is_long_entry = (last_signal[col] >> 10) & 1 is_long_exit = (last_signal[col] >> 9) & 1 is_short_entry = (last_signal[col] >> 8) & 1 is_short_exit = (last_signal[col] >> 7) & 1 limit_signal = (last_signal[col] >> 6) & 1 sl_stop_signal = (last_signal[col] >> 5) & 1 tsl_stop_signal = (last_signal[col] >> 4) & 1 tp_stop_signal = (last_signal[col] >> 3) & 1 td_stop_signal = (last_signal[col] >> 2) & 1 dt_stop_signal = (last_signal[col] >> 1) & 1 any_user_signal = is_long_entry or is_long_exit or is_short_entry or is_short_exit any_limit_signal = limit_signal any_stop_signal = ( sl_stop_signal or tsl_stop_signal or tp_stop_signal or td_stop_signal or dt_stop_signal ) # Set initial info exec_limit_set = False exec_limit_set_on_open = False exec_limit_set_on_close = False exec_limit_signal_i = -1 exec_limit_creation_i = -1 exec_limit_init_i = -1 exec_limit_val_price = np.nan exec_limit_price = np.nan exec_limit_size = np.nan exec_limit_size_type = -1 exec_limit_direction = -1 exec_limit_stop_type = -1 exec_limit_bar_zone = -1 exec_stop_set = False exec_stop_set_on_open = False exec_stop_set_on_close = False exec_stop_init_i = -1 exec_stop_val_price = np.nan 
exec_stop_price = np.nan exec_stop_size = np.nan exec_stop_size_type = -1 exec_stop_direction = -1 exec_stop_type = -1 exec_stop_stop_type = -1 exec_stop_delta = np.nan exec_stop_delta_format = -1 exec_stop_make_limit = False exec_stop_bar_zone = -1 user_on_open = False user_on_close = False exec_user_set = False exec_user_val_price = np.nan exec_user_price = np.nan exec_user_size = np.nan exec_user_size_type = -1 exec_user_direction = -1 exec_user_type = -1 exec_user_stop_type = -1 exec_user_make_limit = False exec_user_bar_zone = -1 # Resolve the current bar _i = i - abs(flex_select_nb(from_ago_, i, col)) _open = flex_select_nb(open_, i, col) _high = flex_select_nb(high_, i, col) _low = flex_select_nb(low_, i, col) _close = flex_select_nb(close_, i, col) _high, _low = resolve_hl_nb( open=_open, high=_high, low=_low, close=_close, ) # Process the limit signal if any_limit_signal: # Check whether the limit price was hit _signal_i = last_limit_info["signal_idx"][col] _creation_i = last_limit_info["creation_idx"][col] _init_i = last_limit_info["init_idx"][col] _price = last_limit_info["init_price"][col] _size = last_limit_info["init_size"][col] _size_type = last_limit_info["init_size_type"][col] _direction = last_limit_info["init_direction"][col] _stop_type = last_limit_info["init_stop_type"][col] _delta = last_limit_info["delta"][col] _delta_format = last_limit_info["delta_format"][col] _tif = last_limit_info["tif"][col] _expiry = last_limit_info["expiry"][col] _time_delta_format = last_limit_info["time_delta_format"][col] _reverse = last_limit_info["reverse"][col] _order_price = last_limit_info["order_price"][col] limit_expired_on_open, limit_expired = check_limit_expired_nb( creation_idx=_creation_i, i=i, tif=_tif, expiry=_expiry, time_delta_format=_time_delta_format, index=index, freq=freq, ) limit_price, limit_hit_on_open, limit_hit = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, size=_size, direction=_direction, 
limit_delta=_delta, delta_format=_delta_format, limit_reverse=_reverse, can_use_ohlc=True, check_open=True, hard_limit=_order_price == LimitOrderPrice.HardLimit, ) # Resolve the price limit_price = resolve_limit_order_price_nb( limit_price=limit_price, close=_close, limit_order_price=_order_price, ) if limit_expired_on_open or (not limit_hit_on_open and limit_expired): # Expired limit signal any_limit_signal = False last_limit_info["signal_idx"][col] = -1 last_limit_info["creation_idx"][col] = -1 last_limit_info["init_idx"][col] = -1 last_limit_info["init_price"][col] = np.nan last_limit_info["init_size"][col] = np.nan last_limit_info["init_size_type"][col] = -1 last_limit_info["init_direction"][col] = -1 last_limit_info["delta"][col] = np.nan last_limit_info["delta_format"][col] = -1 last_limit_info["tif"][col] = -1 last_limit_info["expiry"][col] = -1 last_limit_info["time_delta_format"][col] = -1 last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = np.nan else: # Save info if limit_hit: # Executable limit signal exec_limit_set = True exec_limit_set_on_open = limit_hit_on_open exec_limit_set_on_close = _order_price == LimitOrderPrice.Close exec_limit_signal_i = _signal_i exec_limit_creation_i = _creation_i exec_limit_init_i = _init_i if np.isinf(limit_price) and limit_price > 0: exec_limit_val_price = _close elif np.isinf(limit_price) and limit_price < 0: exec_limit_val_price = _open else: exec_limit_val_price = limit_price exec_limit_price = limit_price exec_limit_size = _size exec_limit_size_type = _size_type exec_limit_direction = _direction exec_limit_stop_type = _stop_type # Process the stop signal if any_stop_signal: # Check SL sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = np.nan, False, False if sl_stop_signal: # Check against high and low sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, 
init_price=last_sl_info["init_price"][col], stop=last_sl_info["stop"][col], delta_format=last_sl_info["delta_format"][col], hit_below=True, hard_stop=last_sl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TSL and TTP tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = np.nan, False, False if tsl_stop_signal: # Update peak price using open if last_position[col] > 0: if _open > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _open else: if _open < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _open # Check threshold against previous bars and open if np.isnan(last_tsl_info["th"][col]): th_hit = True else: th_hit = check_tsl_th_hit_nb( is_position_long=last_position[col] > 0, init_price=last_tsl_info["init_price"][col], peak_price=last_tsl_info["peak_price"][col], threshold=last_tsl_info["th"][col], delta_format=last_tsl_info["delta_format"][col], ) if th_hit: tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tsl_info["peak_price"][col], stop=last_tsl_info["stop"][col], delta_format=last_tsl_info["delta_format"][col], hit_below=True, hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Update peak price using full bar if last_position[col] > 0: if _high > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col] ) 
last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _high else: if _low < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _low if not tsl_stop_hit: # Check threshold against full bar if not th_hit: if np.isnan(last_tsl_info["th"][col]): th_hit = True else: th_hit = check_tsl_th_hit_nb( is_position_long=last_position[col] > 0, init_price=last_tsl_info["init_price"][col], peak_price=last_tsl_info["peak_price"][col], threshold=last_tsl_info["th"][col], delta_format=last_tsl_info["delta_format"][col], ) if th_hit: # Check threshold against close tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tsl_info["peak_price"][col], stop=last_tsl_info["stop"][col], delta_format=last_tsl_info["delta_format"][col], hit_below=True, can_use_ohlc=False, hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TP tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = np.nan, False, False if tp_stop_signal: tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = check_stop_hit_nb( open=_open, high=_high, low=_low, close=_close, is_position_long=last_position[col] > 0, init_price=last_tp_info["init_price"][col], stop=last_tp_info["stop"][col], delta_format=last_tp_info["delta_format"][col], hit_below=False, hard_stop=last_tp_info["exit_price"][col] == StopExitPrice.HardStop, ) # Check TD td_stop_price, td_stop_hit_on_open, td_stop_hit = np.nan, False, False if td_stop_signal: td_stop_hit_on_open, td_stop_hit = check_td_stop_hit_nb( init_idx=last_td_info["init_idx"][col], i=i, stop=last_td_info["stop"][col], time_delta_format=last_td_info["time_delta_format"][col], index=index, freq=freq, ) if np.isnan(_open): 
td_stop_hit_on_open = False if td_stop_hit_on_open: td_stop_price = _open else: td_stop_price = _close # Check DT dt_stop_price, dt_stop_hit_on_open, dt_stop_hit = np.nan, False, False if dt_stop_signal: dt_stop_hit_on_open, dt_stop_hit = check_dt_stop_hit_nb( i=i, stop=last_dt_info["stop"][col], time_delta_format=last_dt_info["time_delta_format"][col], index=index, freq=freq, ) if np.isnan(_open): dt_stop_hit_on_open = False if dt_stop_hit_on_open: dt_stop_price = _open else: dt_stop_price = _close # Resolve the stop signal sl_hit = False tsl_hit = False tp_hit = False td_hit = False dt_hit = False if sl_stop_hit_on_open: sl_hit = True elif tsl_stop_hit_on_open: tsl_hit = True elif tp_stop_hit_on_open: tp_hit = True elif td_stop_hit_on_open: td_hit = True elif dt_stop_hit_on_open: dt_hit = True elif sl_stop_hit: sl_hit = True elif tsl_stop_hit: tsl_hit = True elif tp_stop_hit: tp_hit = True elif td_stop_hit: td_hit = True elif dt_stop_hit: dt_hit = True if sl_hit: stop_price, stop_hit_on_open, stop_hit = sl_stop_price, sl_stop_hit_on_open, sl_stop_hit _stop_type = StopType.SL _init_i = last_sl_info["init_idx"][col] _stop_exit_price = last_sl_info["exit_price"][col] _stop_exit_size = last_sl_info["exit_size"][col] _stop_exit_size_type = last_sl_info["exit_size_type"][col] _stop_exit_type = last_sl_info["exit_type"][col] _stop_order_type = last_sl_info["order_type"][col] _limit_delta = last_sl_info["limit_delta"][col] _delta_format = last_sl_info["delta_format"][col] _ladder = last_sl_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_sl_info["step"][col] if step < n_sl_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=sl_stop_, step=step, col=col, init_price=last_sl_info["init_price"][col], init_position=last_sl_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_sl_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = 
SizeType.Amount elif tsl_hit: stop_price, stop_hit_on_open, stop_hit = ( tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit, ) if np.isnan(last_tsl_info["th"][col]): _stop_type = StopType.TSL else: _stop_type = StopType.TTP _init_i = last_tsl_info["init_idx"][col] _stop_exit_price = last_tsl_info["exit_price"][col] _stop_exit_size = last_tsl_info["exit_size"][col] _stop_exit_size_type = last_tsl_info["exit_size_type"][col] _stop_exit_type = last_tsl_info["exit_type"][col] _stop_order_type = last_tsl_info["order_type"][col] _limit_delta = last_tsl_info["limit_delta"][col] _delta_format = last_tsl_info["delta_format"][col] _ladder = last_tsl_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_tsl_info["step"][col] if step < n_tsl_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=tsl_stop_, step=step, col=col, init_price=last_tsl_info["init_price"][col], init_position=last_tsl_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_tsl_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = SizeType.Amount elif tp_hit: stop_price, stop_hit_on_open, stop_hit = tp_stop_price, tp_stop_hit_on_open, tp_stop_hit _stop_type = StopType.TP _init_i = last_tp_info["init_idx"][col] _stop_exit_price = last_tp_info["exit_price"][col] _stop_exit_size = last_tp_info["exit_size"][col] _stop_exit_size_type = last_tp_info["exit_size_type"][col] _stop_exit_type = last_tp_info["exit_type"][col] _stop_order_type = last_tp_info["order_type"][col] _limit_delta = last_tp_info["limit_delta"][col] _delta_format = last_tp_info["delta_format"][col] _ladder = last_tp_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_tp_info["step"][col] if step < n_tp_steps: _stop_exit_size = get_stop_ladder_exit_size_nb( stop_=tp_stop_, step=step, col=col, init_price=last_tp_info["init_price"][col], 
init_position=last_tp_info["init_position"][col], position_now=last_position[col], ladder=_ladder, delta_format=last_tp_info["delta_format"][col], hit_below=True, ) _stop_exit_size_type = SizeType.Amount elif td_hit: stop_price, stop_hit_on_open, stop_hit = td_stop_price, td_stop_hit_on_open, td_stop_hit _stop_type = StopType.TD _init_i = last_td_info["init_idx"][col] _stop_exit_price = last_td_info["exit_price"][col] _stop_exit_size = last_td_info["exit_size"][col] _stop_exit_size_type = last_td_info["exit_size_type"][col] _stop_exit_type = last_td_info["exit_type"][col] _stop_order_type = last_td_info["order_type"][col] _limit_delta = last_td_info["limit_delta"][col] _delta_format = last_td_info["delta_format"][col] _ladder = last_td_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_td_info["step"][col] if step < n_td_steps: _stop_exit_size = get_time_stop_ladder_exit_size_nb( stop_=td_stop_, step=step, col=col, init_idx=last_td_info["init_idx"][col], init_position=last_td_info["init_position"][col], position_now=last_position[col], ladder=_ladder, time_delta_format=last_td_info["time_delta_format"][col], index=index, ) _stop_exit_size_type = SizeType.Amount elif dt_hit: stop_price, stop_hit_on_open, stop_hit = dt_stop_price, dt_stop_hit_on_open, dt_stop_hit _stop_type = StopType.DT _init_i = last_dt_info["init_idx"][col] _stop_exit_price = last_dt_info["exit_price"][col] _stop_exit_size = last_dt_info["exit_size"][col] _stop_exit_size_type = last_dt_info["exit_size_type"][col] _stop_exit_type = last_dt_info["exit_type"][col] _stop_order_type = last_dt_info["order_type"][col] _limit_delta = last_dt_info["limit_delta"][col] _delta_format = last_dt_info["delta_format"][col] _ladder = last_dt_info["ladder"][col] if np.isnan(_stop_exit_size): if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic: step = last_dt_info["step"][col] if step < n_dt_steps: _stop_exit_size = 
get_time_stop_ladder_exit_size_nb( stop_=dt_stop_, step=step, col=col, init_idx=last_dt_info["init_idx"][col], init_position=last_dt_info["init_position"][col], position_now=last_position[col], ladder=_ladder, time_delta_format=last_dt_info["time_delta_format"][col], index=index, ) _stop_exit_size_type = SizeType.Amount else: stop_price, stop_hit_on_open, stop_hit = np.nan, False, False if stop_hit: # Stop price was hit # Resolve the final stop signal _accumulate = flex_select_nb(accumulate_, i, col) _size = flex_select_nb(size_, i, col) _size_type = flex_select_nb(size_type_, i, col) if not np.isnan(_stop_exit_size): _accumulate = True if _stop_exit_type == StopExitType.Close: _stop_exit_type = StopExitType.CloseReduce _size = _stop_exit_size if _stop_exit_size_type != -1: _size_type = _stop_exit_size_type ( stop_is_long_entry, stop_is_long_exit, stop_is_short_entry, stop_is_short_exit, _accumulate, ) = generate_stop_signal_nb( position_now=last_position[col], stop_exit_type=_stop_exit_type, accumulate=_accumulate, ) # Resolve the price _price = resolve_stop_exit_price_nb( stop_price=stop_price, close=_close, stop_exit_price=_stop_exit_price, ) # Convert both signals to size (direction-aware), size type, and direction _size, _size_type, _direction = signal_to_size_nb( position_now=last_position[col], val_price_now=_price, value_now=last_value[group], is_long_entry=stop_is_long_entry, is_long_exit=stop_is_long_exit, is_short_entry=stop_is_short_entry, is_short_exit=stop_is_short_exit, size=_size, size_type=_size_type, accumulate=_accumulate, ) if not np.isnan(_size): # Executable stop signal can_execute = True if _stop_order_type == OrderType.Limit: # Use close to check whether the limit price was hit if _stop_exit_price == StopExitPrice.Close: # Cannot place a limit order at the close price and execute right away can_execute = False if can_execute: limit_price, _, can_execute = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, 
size=_size, direction=_direction, limit_delta=_limit_delta, delta_format=_delta_format, limit_reverse=False, can_use_ohlc=stop_hit_on_open, check_open=False, hard_limit=False, ) if can_execute: _price = limit_price # Save info exec_stop_set = True exec_stop_set_on_open = stop_hit_on_open exec_stop_set_on_close = _stop_exit_price == StopExitPrice.Close exec_stop_init_i = _init_i if np.isinf(_price) and _price > 0: exec_stop_val_price = _close elif np.isinf(_price) and _price < 0: exec_stop_val_price = _open else: exec_stop_val_price = _price exec_stop_price = _price exec_stop_size = _size exec_stop_size_type = _size_type exec_stop_direction = _direction exec_stop_type = _stop_order_type exec_stop_stop_type = _stop_type exec_stop_delta = _limit_delta exec_stop_delta_format = _delta_format exec_stop_make_limit = not can_execute # Process user signal if any_user_signal: if _i < 0: _price = np.nan _size = np.nan _size_type = -1 _direction = -1 else: _accumulate = flex_select_nb(accumulate_, _i, col) if is_long_entry or is_short_entry: # Resolve any single-direction conflicts _upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col) is_long_entry, is_long_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_long_entry, is_exit=is_long_exit, direction=Direction.LongOnly, conflict_mode=_upon_long_conflict, ) _upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col) is_short_entry, is_short_exit = resolve_signal_conflict_nb( position_now=last_position[col], is_entry=is_short_entry, is_exit=is_short_exit, direction=Direction.ShortOnly, conflict_mode=_upon_short_conflict, ) # Resolve any multi-direction conflicts _upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col) is_long_entry, is_short_entry = resolve_dir_conflict_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_short_entry=is_short_entry, upon_dir_conflict=_upon_dir_conflict, ) # Resolve an opposite entry _upon_opposite_entry = 
flex_select_nb(upon_opposite_entry_, _i, col) ( is_long_entry, is_long_exit, is_short_entry, is_short_exit, _accumulate, ) = resolve_opposite_entry_nb( position_now=last_position[col], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, upon_opposite_entry=_upon_opposite_entry, accumulate=_accumulate, ) # Resolve the price _price = flex_select_nb(price_, _i, col) # Convert both signals to size (direction-aware), size type, and direction _val_price = flex_select_nb(val_price_, i, col) if np.isinf(_val_price) and _val_price > 0: if np.isinf(_price) and _price > 0: _val_price = _close elif np.isinf(_price) and _price < 0: _val_price = _open else: _val_price = _price elif np.isnan(_val_price) or (np.isinf(_val_price) and _val_price < 0): _val_price = last_val_price[col] _size, _size_type, _direction = signal_to_size_nb( position_now=last_position[col], val_price_now=_val_price, value_now=last_value[group], is_long_entry=is_long_entry, is_long_exit=is_long_exit, is_short_entry=is_short_entry, is_short_exit=is_short_exit, size=flex_select_nb(size_, _i, col), size_type=flex_select_nb(size_type_, _i, col), accumulate=_accumulate, ) if np.isinf(_price): if _price > 0: user_on_close = True else: user_on_open = True if not np.isnan(_size): # Executable user signal can_execute = True _order_type = flex_select_nb(order_type_, _i, col) if _order_type == OrderType.Limit: # Use close to check whether the limit price was hit can_use_ohlc = False if np.isinf(_price): if _price > 0: # Cannot place a limit order at the close price and execute right away _price = _close can_execute = False else: can_use_ohlc = True _price = _open if can_execute: _limit_delta = flex_select_nb(limit_delta_, _i, col) _delta_format = flex_select_nb(delta_format_, _i, col) _limit_reverse = flex_select_nb(limit_reverse_, _i, col) limit_price, _, can_execute = check_limit_hit_nb( open=_open, high=_high, low=_low, close=_close, price=_price, 
size=_size, direction=_direction, limit_delta=_limit_delta, delta_format=_delta_format, limit_reverse=_limit_reverse, can_use_ohlc=can_use_ohlc, check_open=False, hard_limit=False, ) if can_execute: _price = limit_price # Save info exec_user_set = True exec_user_val_price = _val_price exec_user_price = _price exec_user_size = _size exec_user_size_type = _size_type exec_user_direction = _direction exec_user_type = _order_type exec_user_stop_type = -1 exec_user_make_limit = not can_execute if ( exec_limit_set or exec_stop_set or exec_user_set or ((any_limit_signal or any_stop_signal) and any_user_signal) ): # Choose the main executable signal # Priority: limit -> stop -> user # Check whether the main signal comes on open keep_limit = True keep_stop = True execute_limit = False execute_stop = False execute_user = False if exec_limit_set_on_open: keep_limit = False keep_stop = False execute_limit = True if exec_limit_set_on_close: exec_limit_bar_zone = BarZone.Close else: exec_limit_bar_zone = BarZone.Open elif exec_stop_set_on_open: keep_limit = False keep_stop = _ladder execute_stop = True if exec_stop_set_on_close: exec_stop_bar_zone = BarZone.Close else: exec_stop_bar_zone = BarZone.Open elif any_user_signal and user_on_open: execute_user = True if any_limit_signal and (execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and (execute_user or not exec_user_set): keep_stop, execute_user = resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), 
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Open if not execute_limit and not execute_stop and not execute_user: # Check whether the main signal comes in the middle of the bar if exec_limit_set and not exec_limit_set_on_open and keep_limit: keep_limit = False keep_stop = False execute_limit = True exec_limit_bar_zone = BarZone.Middle elif ( exec_stop_set and not exec_stop_set_on_open and not exec_stop_set_on_close and keep_stop ): keep_limit = False keep_stop = _ladder execute_stop = True exec_stop_bar_zone = BarZone.Middle elif any_user_signal and not user_on_open and not user_on_close: execute_user = True if any_limit_signal and keep_limit and (execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and keep_stop and (execute_user or not exec_user_set): keep_stop, execute_user = resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Middle if not execute_limit and not execute_stop and not execute_user: # Check whether the main signal comes on close if exec_stop_set_on_close and keep_stop: keep_limit = False keep_stop = _ladder execute_stop = True exec_stop_bar_zone = BarZone.Close elif any_user_signal and user_on_close: execute_user = True if any_limit_signal and keep_limit and 
(execute_user or not exec_user_set): stop_size = get_diraware_size_nb( size=last_limit_info["init_size"][col], direction=last_limit_info["init_direction"][col], ) keep_limit, execute_user = resolve_pending_conflict_nb( is_pending_long=stop_size >= 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col), ) if any_stop_signal and keep_stop and (execute_user or not exec_user_set): keep_stop, execute_user = resolve_pending_conflict_nb( is_pending_long=last_position[col] < 0, is_user_long=is_long_entry or is_short_exit, upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col), upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col), ) if not exec_user_set: execute_user = False if execute_user: exec_user_bar_zone = BarZone.Close # Process the limit signal if execute_limit: # Execute the signal main_info["bar_zone"][col] = exec_limit_bar_zone main_info["signal_idx"][col] = exec_limit_signal_i main_info["creation_idx"][col] = exec_limit_creation_i main_info["idx"][col] = exec_limit_init_i main_info["val_price"][col] = exec_limit_val_price main_info["price"][col] = exec_limit_price main_info["size"][col] = exec_limit_size main_info["size_type"][col] = exec_limit_size_type main_info["direction"][col] = exec_limit_direction main_info["type"][col] = OrderType.Limit main_info["stop_type"][col] = exec_limit_stop_type if execute_limit or (any_limit_signal and not keep_limit): # Clear the pending info any_limit_signal = False last_limit_info["signal_idx"][col] = -1 last_limit_info["creation_idx"][col] = -1 last_limit_info["init_idx"][col] = -1 last_limit_info["init_price"][col] = np.nan last_limit_info["init_size"][col] = np.nan last_limit_info["init_size_type"][col] = -1 last_limit_info["init_direction"][col] = -1 last_limit_info["init_stop_type"][col] = -1 last_limit_info["delta"][col] = np.nan last_limit_info["delta_format"][col] = -1 
last_limit_info["tif"][col] = -1 last_limit_info["expiry"][col] = -1 last_limit_info["time_delta_format"][col] = -1 last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = np.nan # Process the stop signal if execute_stop: # Execute the signal if exec_stop_make_limit: if any_limit_signal: raise ValueError("Only one active limit signal is allowed at a time") _limit_tif = flex_select_nb(limit_tif_, i, col) _limit_expiry = flex_select_nb(limit_expiry_, i, col) _time_delta_format = flex_select_nb(time_delta_format_, i, col) _limit_order_price = flex_select_nb(limit_order_price_, i, col) last_limit_info["signal_idx"][col] = exec_stop_init_i last_limit_info["creation_idx"][col] = i last_limit_info["init_idx"][col] = i last_limit_info["init_price"][col] = exec_stop_price last_limit_info["init_size"][col] = exec_stop_size last_limit_info["init_size_type"][col] = exec_stop_size_type last_limit_info["init_direction"][col] = exec_stop_direction last_limit_info["init_stop_type"][col] = exec_stop_stop_type last_limit_info["delta"][col] = exec_stop_delta last_limit_info["delta_format"][col] = exec_stop_delta_format last_limit_info["tif"][col] = _limit_tif last_limit_info["expiry"][col] = _limit_expiry last_limit_info["time_delta_format"][col] = _time_delta_format last_limit_info["reverse"][col] = False last_limit_info["order_price"][col] = _limit_order_price else: main_info["bar_zone"][col] = exec_stop_bar_zone main_info["signal_idx"][col] = exec_stop_init_i main_info["creation_idx"][col] = i main_info["idx"][col] = i main_info["val_price"][col] = exec_stop_val_price main_info["price"][col] = exec_stop_price main_info["size"][col] = exec_stop_size main_info["size_type"][col] = exec_stop_size_type main_info["direction"][col] = exec_stop_direction main_info["type"][col] = exec_stop_type main_info["stop_type"][col] = exec_stop_stop_type if any_stop_signal and not keep_stop: # Clear the pending info any_stop_signal = False last_sl_info["init_idx"][col] = -1 
last_sl_info["init_price"][col] = np.nan last_sl_info["init_position"][col] = np.nan last_sl_info["stop"][col] = np.nan last_sl_info["exit_price"][col] = -1 last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = -1 last_sl_info["order_type"][col] = -1 last_sl_info["limit_delta"][col] = np.nan last_sl_info["delta_format"][col] = -1 last_sl_info["ladder"][col] = -1 last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 last_tsl_info["init_idx"][col] = -1 last_tsl_info["init_price"][col] = np.nan last_tsl_info["init_position"][col] = np.nan last_tsl_info["peak_idx"][col] = -1 last_tsl_info["peak_price"][col] = np.nan last_tsl_info["stop"][col] = np.nan last_tsl_info["th"][col] = np.nan last_tsl_info["exit_price"][col] = -1 last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = -1 last_tsl_info["order_type"][col] = -1 last_tsl_info["limit_delta"][col] = np.nan last_tsl_info["delta_format"][col] = -1 last_tsl_info["ladder"][col] = -1 last_tsl_info["step"][col] = -1 last_tsl_info["step_idx"][col] = -1 last_tp_info["init_idx"][col] = -1 last_tp_info["init_price"][col] = np.nan last_tp_info["init_position"][col] = np.nan last_tp_info["stop"][col] = np.nan last_tp_info["exit_price"][col] = -1 last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = -1 last_tp_info["order_type"][col] = -1 last_tp_info["limit_delta"][col] = np.nan last_tp_info["delta_format"][col] = -1 last_tp_info["ladder"][col] = -1 last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 last_td_info["init_idx"][col] = -1 last_td_info["init_position"][col] = np.nan last_td_info["stop"][col] = -1 last_td_info["exit_price"][col] = -1 last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = -1 last_td_info["order_type"][col] = -1 
last_td_info["limit_delta"][col] = np.nan last_td_info["delta_format"][col] = -1 last_td_info["time_delta_format"][col] = -1 last_td_info["ladder"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 last_dt_info["init_idx"][col] = -1 last_dt_info["init_position"][col] = np.nan last_dt_info["stop"][col] = -1 last_dt_info["exit_price"][col] = -1 last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = -1 last_dt_info["order_type"][col] = -1 last_dt_info["limit_delta"][col] = np.nan last_dt_info["delta_format"][col] = -1 last_dt_info["time_delta_format"][col] = -1 last_dt_info["ladder"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 # Process the user signal if execute_user: # Execute the signal if _i >= 0: if exec_user_make_limit: if any_limit_signal: raise ValueError("Only one active limit signal is allowed at a time") _limit_delta = flex_select_nb(limit_delta_, _i, col) _delta_format = flex_select_nb(delta_format_, _i, col) _limit_tif = flex_select_nb(limit_tif_, _i, col) _limit_expiry = flex_select_nb(limit_expiry_, _i, col) _time_delta_format = flex_select_nb(time_delta_format_, _i, col) _limit_reverse = flex_select_nb(limit_reverse_, _i, col) _limit_order_price = flex_select_nb(limit_order_price_, _i, col) last_limit_info["signal_idx"][col] = _i last_limit_info["creation_idx"][col] = i last_limit_info["init_idx"][col] = _i last_limit_info["init_price"][col] = exec_user_price last_limit_info["init_size"][col] = exec_user_size last_limit_info["init_size_type"][col] = exec_user_size_type last_limit_info["init_direction"][col] = exec_user_direction last_limit_info["init_stop_type"][col] = -1 last_limit_info["delta"][col] = _limit_delta last_limit_info["delta_format"][col] = _delta_format last_limit_info["tif"][col] = _limit_tif last_limit_info["expiry"][col] = _limit_expiry last_limit_info["time_delta_format"][col] = _time_delta_format 
last_limit_info["reverse"][col] = _limit_reverse last_limit_info["order_price"][col] = _limit_order_price else: main_info["bar_zone"][col] = exec_user_bar_zone main_info["signal_idx"][col] = _i main_info["creation_idx"][col] = i main_info["idx"][col] = _i main_info["val_price"][col] = exec_user_val_price main_info["price"][col] = exec_user_price main_info["size"][col] = exec_user_size main_info["size_type"][col] = exec_user_size_type main_info["direction"][col] = exec_user_direction main_info["type"][col] = exec_user_type main_info["stop_type"][col] = exec_user_stop_type skip = skip_empty if skip: for col in range(from_col, to_col): if flex_select_nb(log_, i, col): skip = False break if not np.isnan(main_info["size"][col]): skip = False break if not skip: # Check bar zone and update valuation price bar_zone = -1 same_bar_zone = True same_timing = True for c in range(group_len): col = from_col + c if np.isnan(main_info["size"][col]): continue if bar_zone == -1: bar_zone = main_info["bar_zone"][col] if main_info["bar_zone"][col] != bar_zone: same_bar_zone = False same_timing = False if main_info["bar_zone"][col] == BarZone.Middle: same_timing = False _val_price = main_info["val_price"][col] if not np.isnan(_val_price) or not ffill_val_price: last_val_price[col] = _val_price if cash_sharing: # Dynamically sort by order value -> selling comes first to release funds early if call_seq is None: for c in range(group_len): temp_call_seq[c] = c call_seq_now = temp_call_seq[:group_len] else: call_seq_now = call_seq[i, from_col:to_col] if auto_call_seq: # Sort by order value if not same_timing: raise ValueError("Cannot sort orders by value if they are executed at different times") for c in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue # Approximate order value exec_state = ExecState( cash=last_cash[group] if cash_sharing else last_cash[col], 
position=last_position[col], debt=last_debt[col], locked_cash=last_locked_cash[col], free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col], val_price=last_val_price[col], value=last_value[group] if cash_sharing else last_value[col], ) temp_sort_by[c] = approx_order_value_nb( exec_state=exec_state, size=main_info["size"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], ) insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) else: if not same_bar_zone: # Sort by bar zone for c in range(group_len): if call_seq_now[c] != c: raise ValueError("Call sequence must follow CallSeqType.Default") col = from_col + c if np.isnan(main_info["size"][col]): continue temp_sort_by[c] = main_info["bar_zone"][col] insert_argsort_nb(temp_sort_by[:group_len], call_seq_now) for k in range(group_len): if cash_sharing: c = call_seq_now[k] if c >= group_len: raise ValueError("Call index out of bounds of the group") else: c = k col = from_col + c if skip_empty and np.isnan(main_info["size"][col]): # shortcut continue # Get current values per column position_before = position_now = last_position[col] debt_before = debt_now = last_debt[col] locked_cash_before = locked_cash_now = last_locked_cash[col] val_price_before = val_price_now = last_val_price[col] cash_before = cash_now = last_cash[group] if cash_sharing else last_cash[col] free_cash_before = free_cash_now = ( last_free_cash[group] if cash_sharing else last_free_cash[col] ) value_before = value_now = last_value[group] if cash_sharing else last_value[col] return_before = return_now = last_return[group] if cash_sharing else last_return[col] # Generate the next order _i = main_info["idx"][col] if main_info["type"][col] == OrderType.Limit: _slippage = 0.0 else: _slippage = float(flex_select_nb(slippage_, _i, col)) _min_size = flex_select_nb(min_size_, _i, col) _max_size = flex_select_nb(max_size_, _i, col) _size_type = flex_select_nb(size_type_, _i, col) if _size_type != 
main_info["size_type"][col]: if not np.isnan(_min_size): _min_size, _ = resolve_size_nb( size=_min_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) if not np.isnan(_max_size): _max_size, _ = resolve_size_nb( size=_max_size, size_type=_size_type, position=position_now, val_price=val_price_now, value=value_now, target_size_type=main_info["size_type"][col], as_requirement=True, ) order = order_nb( size=main_info["size"][col], price=main_info["price"][col], size_type=main_info["size_type"][col], direction=main_info["direction"][col], fees=flex_select_nb(fees_, _i, col), fixed_fees=flex_select_nb(fixed_fees_, _i, col), slippage=_slippage, min_size=_min_size, max_size=_max_size, size_granularity=flex_select_nb(size_granularity_, _i, col), leverage=flex_select_nb(leverage_, _i, col), leverage_mode=flex_select_nb(leverage_mode_, _i, col), reject_prob=flex_select_nb(reject_prob_, _i, col), price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col), allow_partial=flex_select_nb(allow_partial_, _i, col), raise_reject=flex_select_nb(raise_reject_, _i, col), log=flex_select_nb(log_, _i, col), ) # Process the order price_area = PriceArea( open=flex_select_nb(open_, i, col), high=flex_select_nb(high_, i, col), low=flex_select_nb(low_, i, col), close=flex_select_nb(close_, i, col), ) exec_state = ExecState( cash=cash_now, position=position_now, debt=debt_now, locked_cash=locked_cash_now, free_cash=free_cash_now, val_price=val_price_now, value=value_now, ) order_result, new_exec_state = process_order_nb( group=group, col=col, i=i, exec_state=exec_state, order=order, price_area=price_area, update_value=update_value, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, ) # Append more order information if order_result.status == OrderStatus.Filled and order_counts[col] >= 1: 
order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col] order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col] order_records["type"][order_counts[col] - 1, col] = main_info["type"][col] order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col] # Update execution state cash_now = new_exec_state.cash position_now = new_exec_state.position debt_now = new_exec_state.debt locked_cash_now = new_exec_state.locked_cash free_cash_now = new_exec_state.free_cash val_price_now = new_exec_state.val_price value_now = new_exec_state.value # Update position record if fill_pos_info: if order_result.status == OrderStatus.Filled: if order_counts[col] > 0: order_id = order_records["id"][order_counts[col] - 1, col] else: order_id = -1 update_pos_info_nb( last_pos_info[col], i, col, exec_state.position, position_now, order_result, order_id, ) if use_stops: # Update stop price if position_now == 0: # Not in position anymore -> clear stops (irrespective of order success) last_sl_info["init_idx"][col] = -1 last_sl_info["init_price"][col] = np.nan last_sl_info["init_position"][col] = np.nan last_sl_info["stop"][col] = np.nan last_sl_info["exit_price"][col] = -1 last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = -1 last_sl_info["order_type"][col] = -1 last_sl_info["limit_delta"][col] = np.nan last_sl_info["delta_format"][col] = -1 last_sl_info["ladder"][col] = -1 last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 last_tsl_info["init_idx"][col] = -1 last_tsl_info["init_price"][col] = np.nan last_tsl_info["init_position"][col] = np.nan last_tsl_info["peak_idx"][col] = -1 last_tsl_info["peak_price"][col] = np.nan last_tsl_info["stop"][col] = np.nan last_tsl_info["th"][col] = np.nan last_tsl_info["exit_price"][col] = -1 last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 
last_tsl_info["exit_type"][col] = -1 last_tsl_info["order_type"][col] = -1 last_tsl_info["limit_delta"][col] = np.nan last_tsl_info["delta_format"][col] = -1 last_tsl_info["ladder"][col] = -1 last_tsl_info["step"][col] = -1 last_tsl_info["step_idx"][col] = -1 last_tp_info["init_idx"][col] = -1 last_tp_info["init_price"][col] = np.nan last_tp_info["init_position"][col] = np.nan last_tp_info["stop"][col] = np.nan last_tp_info["exit_price"][col] = -1 last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = -1 last_tp_info["order_type"][col] = -1 last_tp_info["limit_delta"][col] = np.nan last_tp_info["delta_format"][col] = -1 last_tp_info["ladder"][col] = -1 last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 last_td_info["init_idx"][col] = -1 last_td_info["init_position"][col] = np.nan last_td_info["stop"][col] = -1 last_td_info["exit_price"][col] = -1 last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = -1 last_td_info["order_type"][col] = -1 last_td_info["limit_delta"][col] = np.nan last_td_info["delta_format"][col] = -1 last_td_info["time_delta_format"][col] = -1 last_td_info["ladder"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 last_dt_info["init_idx"][col] = -1 last_dt_info["init_position"][col] = np.nan last_dt_info["stop"][col] = -1 last_dt_info["exit_price"][col] = -1 last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = -1 last_dt_info["order_type"][col] = -1 last_dt_info["limit_delta"][col] = np.nan last_dt_info["delta_format"][col] = -1 last_dt_info["time_delta_format"][col] = -1 last_dt_info["ladder"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 else: if main_info["stop_type"][col] == StopType.SL: if last_sl_info["ladder"][col]: step = last_sl_info["step"][col] + 1 last_sl_info["exit_size"][col] = np.nan 
last_sl_info["exit_size_type"][col] = -1 if stop_ladder and last_sl_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_sl_steps: last_sl_info["stop"][col] = flex_select_nb(sl_stop_, step, col) last_sl_info["step"][col] = step last_sl_info["step_idx"][col] = i else: last_sl_info["stop"][col] = np.nan last_sl_info["step"][col] = -1 last_sl_info["step_idx"][col] = -1 else: last_sl_info["stop"][col] = np.nan last_sl_info["step"][col] = step last_sl_info["step_idx"][col] = i elif ( main_info["stop_type"][col] == StopType.TSL or main_info["stop_type"][col] == StopType.TTP ): if last_tsl_info["ladder"][col]: step = last_tsl_info["step"][col] + 1 last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 if stop_ladder and last_tsl_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_tsl_steps: last_tsl_info["stop"][col] = flex_select_nb(tsl_stop_, step, col) last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i else: last_tsl_info["stop"][col] = np.nan last_tsl_info["step"][col] = -1 last_tsl_info["step_idx"][col] = -1 else: last_tsl_info["stop"][col] = np.nan last_tsl_info["step"][col] = step last_tsl_info["step_idx"][col] = i elif main_info["stop_type"][col] == StopType.TP: if last_tp_info["ladder"][col]: step = last_tp_info["step"][col] + 1 last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 if stop_ladder and last_tp_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_tp_steps: last_tp_info["stop"][col] = flex_select_nb(tp_stop_, step, col) last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i else: last_tp_info["stop"][col] = np.nan last_tp_info["step"][col] = -1 last_tp_info["step_idx"][col] = -1 else: last_tp_info["stop"][col] = np.nan last_tp_info["step"][col] = step last_tp_info["step_idx"][col] = i elif 
main_info["stop_type"][col] == StopType.TD: if last_td_info["ladder"][col]: step = last_td_info["step"][col] + 1 last_td_info["step"][col] = step last_td_info["step_idx"][col] = i last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 if stop_ladder and last_td_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_td_steps: last_td_info["stop"][col] = flex_select_nb(td_stop_, step, col) last_td_info["step"][col] = step last_td_info["step_idx"][col] = i else: last_td_info["stop"][col] = -1 last_td_info["step"][col] = -1 last_td_info["step_idx"][col] = -1 else: last_td_info["stop"][col] = -1 last_td_info["step"][col] = step last_td_info["step_idx"][col] = i elif main_info["stop_type"][col] == StopType.DT: if last_dt_info["ladder"][col]: step = last_dt_info["step"][col] + 1 last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 if stop_ladder and last_dt_info["ladder"][col] != StopLadderMode.Dynamic: if step < n_dt_steps: last_dt_info["stop"][col] = flex_select_nb(dt_stop_, step, col) last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i else: last_dt_info["stop"][col] = -1 last_dt_info["step"][col] = -1 last_dt_info["step_idx"][col] = -1 else: last_dt_info["stop"][col] = -1 last_dt_info["step"][col] = step last_dt_info["step_idx"][col] = i if order_result.status == OrderStatus.Filled and position_now != 0: # Order filled and in position -> possibly set stops _price = main_info["price"][col] _stop_entry_price = flex_select_nb(stop_entry_price_, i, col) if _stop_entry_price < 0: if _stop_entry_price == StopEntryPrice.ValPrice: new_init_price = val_price_now can_use_ohlc = False elif _stop_entry_price == StopEntryPrice.Price: new_init_price = order.price can_use_ohlc = np.isinf(_price) and _price < 0 if np.isinf(new_init_price): if new_init_price > 0: new_init_price = flex_select_nb(close_, i, col) else: new_init_price = 
flex_select_nb(open_, i, col) elif _stop_entry_price == StopEntryPrice.FillPrice: new_init_price = order_result.price can_use_ohlc = np.isinf(_price) and _price < 0 elif _stop_entry_price == StopEntryPrice.Open: new_init_price = flex_select_nb(open_, i, col) can_use_ohlc = True elif _stop_entry_price == StopEntryPrice.Close: new_init_price = flex_select_nb(close_, i, col) can_use_ohlc = False else: raise ValueError("Invalid StopEntryPrice option") else: new_init_price = _stop_entry_price can_use_ohlc = False if stop_ladder: _sl_stop = flex_select_nb(sl_stop_, 0, col) _tsl_stop = flex_select_nb(tsl_stop_, 0, col) _tp_stop = flex_select_nb(tp_stop_, 0, col) _td_stop = flex_select_nb(td_stop_, 0, col) _dt_stop = flex_select_nb(dt_stop_, 0, col) else: _sl_stop = flex_select_nb(sl_stop_, i, col) _tsl_stop = flex_select_nb(tsl_stop_, i, col) _tp_stop = flex_select_nb(tp_stop_, i, col) _td_stop = flex_select_nb(td_stop_, i, col) _dt_stop = flex_select_nb(dt_stop_, i, col) _tsl_th = flex_select_nb(tsl_th_, i, col) _stop_exit_price = flex_select_nb(stop_exit_price_, i, col) _stop_exit_type = flex_select_nb(stop_exit_type_, i, col) _stop_order_type = flex_select_nb(stop_order_type_, i, col) _stop_limit_delta = flex_select_nb(stop_limit_delta_, i, col) _delta_format = flex_select_nb(delta_format_, i, col) _time_delta_format = flex_select_nb(time_delta_format_, i, col) tsl_updated = False if exec_state.position == 0 or np.sign(position_now) != np.sign(exec_state.position): # Position opened/reversed -> set stops last_sl_info["init_idx"][col] = i last_sl_info["init_price"][col] = new_init_price last_sl_info["init_position"][col] = position_now last_sl_info["stop"][col] = _sl_stop last_sl_info["exit_price"][col] = _stop_exit_price last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = _stop_exit_type last_sl_info["order_type"][col] = _stop_order_type last_sl_info["limit_delta"][col] = _stop_limit_delta 
last_sl_info["delta_format"][col] = _delta_format last_sl_info["ladder"][col] = stop_ladder last_sl_info["step"][col] = 0 last_sl_info["step_idx"][col] = i tsl_updated = True last_tsl_info["init_idx"][col] = i last_tsl_info["init_price"][col] = new_init_price last_tsl_info["init_position"][col] = position_now last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = new_init_price last_tsl_info["stop"][col] = _tsl_stop last_tsl_info["th"][col] = _tsl_th last_tsl_info["exit_price"][col] = _stop_exit_price last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = _stop_exit_type last_tsl_info["order_type"][col] = _stop_order_type last_tsl_info["limit_delta"][col] = _stop_limit_delta last_tsl_info["delta_format"][col] = _delta_format last_tsl_info["ladder"][col] = stop_ladder last_tsl_info["step"][col] = 0 last_tsl_info["step_idx"][col] = i last_tp_info["init_idx"][col] = i last_tp_info["init_price"][col] = new_init_price last_tp_info["init_position"][col] = position_now last_tp_info["stop"][col] = _tp_stop last_tp_info["exit_price"][col] = _stop_exit_price last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = _stop_exit_type last_tp_info["order_type"][col] = _stop_order_type last_tp_info["limit_delta"][col] = _stop_limit_delta last_tp_info["delta_format"][col] = _delta_format last_tp_info["ladder"][col] = stop_ladder last_tp_info["step"][col] = 0 last_tp_info["step_idx"][col] = i last_td_info["init_idx"][col] = i last_td_info["init_position"][col] = position_now last_td_info["stop"][col] = _td_stop last_td_info["exit_price"][col] = _stop_exit_price last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = _stop_exit_type last_td_info["order_type"][col] = _stop_order_type last_td_info["limit_delta"][col] = _stop_limit_delta last_td_info["delta_format"][col] = _delta_format 
last_td_info["time_delta_format"][col] = _time_delta_format last_td_info["ladder"][col] = stop_ladder last_td_info["step"][col] = 0 last_td_info["step_idx"][col] = i last_dt_info["init_idx"][col] = i last_dt_info["init_position"][col] = position_now last_dt_info["stop"][col] = _dt_stop last_dt_info["exit_price"][col] = _stop_exit_price last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = _stop_exit_type last_dt_info["order_type"][col] = _stop_order_type last_dt_info["limit_delta"][col] = _stop_limit_delta last_dt_info["delta_format"][col] = _delta_format last_dt_info["time_delta_format"][col] = _time_delta_format last_dt_info["ladder"][col] = stop_ladder last_dt_info["step"][col] = 0 last_dt_info["step_idx"][col] = i elif abs(position_now) > abs(exec_state.position): # Position increased -> keep/override stops _upon_stop_update = flex_select_nb(upon_stop_update_, i, col) if should_update_stop_nb(new_stop=_sl_stop, upon_stop_update=_upon_stop_update): last_sl_info["init_idx"][col] = i last_sl_info["init_price"][col] = new_init_price last_sl_info["init_position"][col] = position_now last_sl_info["stop"][col] = _sl_stop last_sl_info["exit_price"][col] = _stop_exit_price last_sl_info["exit_size"][col] = np.nan last_sl_info["exit_size_type"][col] = -1 last_sl_info["exit_type"][col] = _stop_exit_type last_sl_info["order_type"][col] = _stop_order_type last_sl_info["limit_delta"][col] = _stop_limit_delta last_sl_info["delta_format"][col] = _delta_format last_sl_info["ladder"][col] = stop_ladder last_sl_info["step"][col] = 0 last_sl_info["step_idx"][col] = i if should_update_stop_nb(new_stop=_tsl_stop, upon_stop_update=_upon_stop_update): tsl_updated = True last_tsl_info["init_idx"][col] = i last_tsl_info["init_price"][col] = new_init_price last_tsl_info["init_position"][col] = position_now last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = new_init_price last_tsl_info["stop"][col] = _tsl_stop 
last_tsl_info["th"][col] = _tsl_th last_tsl_info["exit_price"][col] = _stop_exit_price last_tsl_info["exit_size"][col] = np.nan last_tsl_info["exit_size_type"][col] = -1 last_tsl_info["exit_type"][col] = _stop_exit_type last_tsl_info["order_type"][col] = _stop_order_type last_tsl_info["limit_delta"][col] = _stop_limit_delta last_tsl_info["delta_format"][col] = _delta_format last_tsl_info["ladder"][col] = stop_ladder last_tsl_info["step"][col] = 0 last_tsl_info["step_idx"][col] = i if should_update_stop_nb(new_stop=_tp_stop, upon_stop_update=_upon_stop_update): last_tp_info["init_idx"][col] = i last_tp_info["init_price"][col] = new_init_price last_tp_info["init_position"][col] = position_now last_tp_info["stop"][col] = _tp_stop last_tp_info["exit_price"][col] = _stop_exit_price last_tp_info["exit_size"][col] = np.nan last_tp_info["exit_size_type"][col] = -1 last_tp_info["exit_type"][col] = _stop_exit_type last_tp_info["order_type"][col] = _stop_order_type last_tp_info["limit_delta"][col] = _stop_limit_delta last_tp_info["delta_format"][col] = _delta_format last_tp_info["ladder"][col] = stop_ladder last_tp_info["step"][col] = 0 last_tp_info["step_idx"][col] = i if should_update_time_stop_nb( new_stop=_td_stop, upon_stop_update=_upon_stop_update ): last_td_info["init_idx"][col] = i last_td_info["init_position"][col] = position_now last_td_info["stop"][col] = _td_stop last_td_info["exit_price"][col] = _stop_exit_price last_td_info["exit_size"][col] = np.nan last_td_info["exit_size_type"][col] = -1 last_td_info["exit_type"][col] = _stop_exit_type last_td_info["order_type"][col] = _stop_order_type last_td_info["limit_delta"][col] = _stop_limit_delta last_td_info["delta_format"][col] = _delta_format last_td_info["time_delta_format"][col] = _time_delta_format last_td_info["ladder"][col] = stop_ladder last_td_info["step"][col] = 0 last_td_info["step_idx"][col] = i if should_update_time_stop_nb( new_stop=_dt_stop, upon_stop_update=_upon_stop_update ): 
last_dt_info["init_idx"][col] = i last_dt_info["init_position"][col] = position_now last_dt_info["stop"][col] = _dt_stop last_dt_info["exit_price"][col] = _stop_exit_price last_dt_info["exit_size"][col] = np.nan last_dt_info["exit_size_type"][col] = -1 last_dt_info["exit_type"][col] = _stop_exit_type last_dt_info["order_type"][col] = _stop_order_type last_dt_info["limit_delta"][col] = _stop_limit_delta last_dt_info["delta_format"][col] = _delta_format last_dt_info["time_delta_format"][col] = _time_delta_format last_dt_info["ladder"][col] = stop_ladder last_dt_info["step"][col] = 0 last_dt_info["step_idx"][col] = i if tsl_updated: # Update highest/lowest price if can_use_ohlc: _open = flex_select_nb(open_, i, col) _high = flex_select_nb(high_, i, col) _low = flex_select_nb(low_, i, col) _close = flex_select_nb(close_, i, col) _high, _low = resolve_hl_nb( open=_open, high=_high, low=_low, close=_close, ) else: _open = np.nan _high = _low = _close = flex_select_nb(close_, i, col) if tsl_updated: if position_now > 0: if _high > last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _high elif position_now < 0: if _low < last_tsl_info["peak_price"][col]: if last_tsl_info["delta_format"][col] == DeltaFormat.Target: last_tsl_info["stop"][col] = ( last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col] ) last_tsl_info["peak_idx"][col] = i last_tsl_info["peak_price"][col] = _low # Now becomes last last_position[col] = position_now last_debt[col] = debt_now last_locked_cash[col] = locked_cash_now if not np.isnan(val_price_now) or not ffill_val_price: last_val_price[col] = val_price_now if cash_sharing: last_cash[group] = cash_now last_free_cash[group] = free_cash_now last_value[group] = value_now last_return[group] = return_now else: last_cash[col] = 
cash_now last_free_cash[col] = free_cash_now last_value[col] = value_now last_return[col] = return_now # Call post-signal function post_signal_ctx = PostSignalContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, track_cash_deposits=track_cash_deposits, cash_deposits_out=cash_deposits_out, track_cash_earnings=track_cash_earnings, cash_earnings_out=cash_earnings_out, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, last_limit_info=last_limit_info, last_sl_info=last_sl_info, last_tsl_info=last_tsl_info, last_tp_info=last_tp_info, last_td_info=last_td_info, last_dt_info=last_dt_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, col=col, cash_before=cash_before, position_before=position_before, debt_before=debt_before, locked_cash_before=locked_cash_before, free_cash_before=free_cash_before, val_price_before=val_price_before, value_before=value_before, order_result=order_result, ) post_signal_func_nb(post_signal_ctx, *post_signal_args) for col in range(from_col, to_col): # Update valuation price using current close _close = flex_select_nb(close_, i, col) if not np.isnan(_close) or not ffill_val_price: last_val_price[col] = _close _cash_earnings = flex_select_nb(cash_earnings_, i, col) _cash_dividends = flex_select_nb(cash_dividends_, i, col) _cash_earnings += _cash_dividends * last_position[col] if cash_sharing: last_cash[group] += _cash_earnings last_free_cash[group] += _cash_earnings else: last_cash[col] += 
_cash_earnings last_free_cash[col] += _cash_earnings if track_cash_earnings: cash_earnings_out[i, col] += _cash_earnings # Update value and return if cash_sharing: group_value = last_cash[group] for col in range(from_col, to_col): if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[group] = group_value last_return[group] = get_return_nb( input_value=prev_close_value[group], output_value=last_value[group] - last_cash_deposits[group], ) prev_close_value[group] = last_value[group] else: for col in range(from_col, to_col): group_value = last_cash[col] if last_position[col] != 0: group_value += last_position[col] * last_val_price[col] last_value[col] = group_value last_return[col] = get_return_nb( input_value=prev_close_value[col], output_value=last_value[col] - last_cash_deposits[col], ) prev_close_value[col] = last_value[col] # Update open position stats if fill_pos_info: for col in range(from_col, to_col): update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col]) # Call post-segment function post_segment_ctx = SignalSegmentContext( target_shape=target_shape, group_lens=group_lens, cash_sharing=cash_sharing, index=index, freq=freq, open=open_, high=high_, low=low_, close=close_, init_cash=init_cash_, init_position=init_position_, init_price=init_price_, order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, track_cash_deposits=track_cash_deposits, cash_deposits_out=cash_deposits_out, track_cash_earnings=track_cash_earnings, cash_earnings_out=cash_earnings_out, in_outputs=in_outputs, last_cash=last_cash, last_position=last_position, last_debt=last_debt, last_locked_cash=last_locked_cash, last_free_cash=last_free_cash, last_val_price=last_val_price, last_value=last_value, last_return=last_return, last_pos_info=last_pos_info, last_limit_info=last_limit_info, last_sl_info=last_sl_info, last_tsl_info=last_tsl_info, last_tp_info=last_tp_info, 
last_td_info=last_td_info, last_dt_info=last_dt_info, sim_start=sim_start_, sim_end=sim_end_, group=group, group_len=group_len, from_col=from_col, to_col=to_col, i=i, ) post_segment_func_nb(post_segment_ctx, *post_segment_args) if i >= sim_end_[group] - 1: break sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb( target_shape=target_shape, group_lens=group_lens, sim_start=sim_start_, sim_end=sim_end_, allow_none=True, ) return prepare_sim_out_nb( order_records=order_records, order_counts=order_counts, log_records=log_records, log_counts=log_counts, cash_deposits=cash_deposits_out, cash_earnings=cash_earnings_out, call_seq=call_seq, in_outputs=in_outputs, sim_start=sim_start_out, sim_end=sim_end_out, ) # %
@register_chunkable(
    size=ch.ShapeSizer(arg_query="target_shape", axis=1),
    arg_take_spec=dict(
        entries=base_ch.FlexArraySlicer(axis=1),
        exits=base_ch.FlexArraySlicer(axis=1),
        direction=base_ch.FlexArraySlicer(axis=1),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dir_to_ls_signals_nb(
    target_shape: tp.Shape,
    entries: tp.FlexArray2d,
    exits: tp.FlexArray2d,
    direction: tp.FlexArray2d,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
    """Convert direction-unaware to direction-aware signals.

    Every element of `entries`, `exits`, and `direction` is read with flexible indexing
    and mapped as follows:

    * `Direction.LongOnly`: entry -> long entry, exit -> long exit
    * `Direction.ShortOnly`: entry -> short entry, exit -> short exit
    * any other direction: entry -> long entry, exit -> short entry, both exits are False

    Returns a tuple of four boolean arrays of shape `target_shape`:
    long entries, long exits, short entries, and short exits."""
    long_entries_out = np.empty(target_shape, dtype=np.bool_)
    long_exits_out = np.empty(target_shape, dtype=np.bool_)
    short_entries_out = np.empty(target_shape, dtype=np.bool_)
    short_exits_out = np.empty(target_shape, dtype=np.bool_)

    # Columns are iterated with `prange`; every (i, col) cell of all four outputs is written
    for col in prange(target_shape[1]):
        for i in range(target_shape[0]):
            is_entry = flex_select_nb(entries, i, col)
            is_exit = flex_select_nb(exits, i, col)
            _direction = flex_select_nb(direction, i, col)
            if _direction == Direction.LongOnly:
                long_entries_out[i, col] = is_entry
                long_exits_out[i, col] = is_exit
                short_entries_out[i, col] = False
                short_exits_out[i, col] = False
            elif _direction == Direction.ShortOnly:
                long_entries_out[i, col] = False
                long_exits_out[i, col] = False
                short_entries_out[i, col] = is_entry
                short_exits_out[i, col] = is_exit
            else:
                # Neither long-only nor short-only: an exit is turned into a short entry
                long_entries_out[i, col] = is_entry
                long_exits_out[i, col] = False
                short_entries_out[i, col] = is_exit
                short_exits_out[i, col] = False

    return long_entries_out, long_exits_out, short_entries_out, short_exits_out


# Type of an adjustment callback: takes a `SignalContext` plus variable arguments, returns nothing
AdjustFuncT = tp.Callable[[SignalContext, tp.VarArg()], None]


@register_jitted
def no_adjust_func_nb(c: SignalContext, *args) -> None:
    """Placeholder `adjust_func_nb` that does nothing."""
    return None


# %
# %
# %
# @register_jitted
# def adjust_func_nb(
#     c: SignalContext,
# ) -> None:
#     """Custom adjustment function."""
#     return None
#
#
# %
# %
# %
# %
# % blocks["adjust_func_nb"]
@register_jitted
def holding_enex_signal_func_nb(  # % line.replace("holding_enex_signal_func_nb", "signal_func_nb")
    c: SignalContext,
    direction: int,
    close_at_end: bool,
    adjust_func_nb: AdjustFuncT = no_adjust_func_nb,  # % None
    adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
    """`signal_func_nb` that returns direction-aware signals from holding.

    Calls `adjust_func_nb(c, *adjust_args)` first. Then:

    * If the column has no position and no orders counted yet, returns an entry signal
      (short entry for `Direction.ShortOnly`, long entry otherwise).
    * Else, if the column has a position, `close_at_end` is True, and `c.i` is the last row
      of `c.target_shape`, returns an exit signal matching the sign of the position.
    * Otherwise returns no signal.

    Returns a tuple of long entry, long exit, short entry, and short exit."""
    adjust_func_nb(c, *adjust_args)

    if c.last_position[c.col] == 0:
        # Flat: enter only if this column hasn't recorded any order yet
        if c.order_counts[c.col] == 0:
            if direction == Direction.ShortOnly:
                return False, False, True, False
            return True, False, False, False
    elif close_at_end and c.i == c.target_shape[0] - 1:
        # In position at the last row: close short or long depending on the position sign
        if c.last_position[c.col] < 0:
            return False, False, False, True
        return False, True, False, False
    return False, False, False, False


# %
# %
# % blocks["adjust_func_nb"]
@register_jitted
def dir_signal_func_nb(  # % line.replace("dir_signal_func_nb", "signal_func_nb")
    c: SignalContext,
    entries: tp.FlexArray2d,
    exits: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    from_ago: tp.FlexArray2d,
    adjust_func_nb: AdjustFuncT = no_adjust_func_nb,  # % None
    adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
    """`signal_func_nb` that converts entries, exits, and direction into direction-aware signals.

    The direction of each pair of signals is taken from `direction` argument:

    * True, True, `Direction.LongOnly` -> True, True, False, False
    * True, True, `Direction.ShortOnly` -> False, False, True, True
    * True, True, `Direction.Both` -> True, False, True, False

    Best to use when the direction doesn't change throughout time.

    Signals are read from the row `c.i - abs(from_ago)`; if that row lies before the first row,
    no signal is returned.

    Prior to returning the signals, calls user-defined `adjust_func_nb`, which can be used to adjust
    stop values in the context. Must accept `vectorbtpro.portfolio.enums.SignalContext` and `*adjust_args`,
    and return nothing."""
    adjust_func_nb(c, *adjust_args)

    # Row to read the signals from (current row shifted back by the absolute `from_ago`)
    _i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
    if _i < 0:
        return False, False, False, False
    is_entry = flex_select_nb(entries, _i, c.col)
    is_exit = flex_select_nb(exits, _i, c.col)
    _direction = flex_select_nb(direction, _i, c.col)
    if _direction == Direction.LongOnly:
        return is_entry, is_exit, False, False
    if _direction == Direction.ShortOnly:
        return False, False, is_entry, is_exit
    return is_entry, False, is_exit, False


# %
# %
# % blocks["adjust_func_nb"]
@register_jitted
def ls_signal_func_nb(  # % line.replace("ls_signal_func_nb", "signal_func_nb")
    c: SignalContext,
    long_entries: tp.FlexArray2d,
    long_exits: tp.FlexArray2d,
    short_entries: tp.FlexArray2d,
    short_exits: tp.FlexArray2d,
    from_ago: tp.FlexArray2d,
    adjust_func_nb: AdjustFuncT = no_adjust_func_nb,  # % None
    adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
    """`signal_func_nb` that gets an element of direction-aware signals.

    The direction is already built into the arrays. Best to use when the direction changes frequently
    (for example, if you have one indicator providing long signals and one providing short signals).

    Signals are read from the row `c.i - abs(from_ago)`; if that row lies before the first row,
    no signal is returned.

    Prior to returning the signals, calls user-defined `adjust_func_nb`, which can be used to adjust
    stop values in the context. Must accept `vectorbtpro.portfolio.enums.SignalContext` and `*adjust_args`,
    and return nothing."""
    adjust_func_nb(c, *adjust_args)

    # Row to read the signals from (current row shifted back by the absolute `from_ago`)
    _i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
    if _i < 0:
        return False, False, False, False
    is_long_entry = flex_select_nb(long_entries, _i, c.col)
    is_long_exit = flex_select_nb(long_exits, _i, c.col)
    is_short_entry = flex_select_nb(short_entries, _i, c.col)
    is_short_exit = flex_select_nb(short_exits, _i, c.col)
    return is_long_entry, is_long_exit, is_short_entry, is_short_exit


# %
# %
# % blocks["adjust_func_nb"]
@register_jitted
def order_signal_func_nb(  # % line.replace("order_signal_func_nb", "signal_func_nb")
    c: SignalContext,
    size: tp.FlexArray2d,
    price: tp.FlexArray2d,
    size_type: tp.FlexArray2d,
    direction: tp.FlexArray2d,
    min_size: tp.FlexArray2d,
    max_size: tp.FlexArray2d,
    val_price: tp.FlexArray2d,
    from_ago: tp.FlexArray2d,
    adjust_func_nb: AdjustFuncT = no_adjust_func_nb,  # % None
    adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
    """`signal_func_nb` that converts orders into direction-aware signals.

    You must ensure that `size`, `size_type`, `min_size`, and `max_size` are writeable
    non-flexible arrays and accumulation is enabled.

    The order is read from the row `c.i - abs(from_ago)`. For target size types, the size
    (and the min/max size, if not NaN) is resolved with `resolve_size_nb` and written back
    together with `SizeType.Amount`; for other size types only the absolute size is written back.
    The sign of the direction-aware size then selects the signal.

    Returns a tuple of long entry, long exit, short entry, and short exit."""
    adjust_func_nb(c, *adjust_args)

    # Row to read the order from (current row shifted back by the absolute `from_ago`)
    _i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
    if _i < 0:
        return False, False, False, False
    order_size = float(flex_select_nb(size, _i, c.col))
    if np.isnan(order_size):
        # NaN size means there is no order at this row
        return False, False, False, False
    order_size_type = int(flex_select_nb(size_type, _i, c.col))
    order_direction = int(flex_select_nb(direction, _i, c.col))
    min_order_size = float(flex_select_nb(min_size, _i, c.col))
    max_order_size = float(flex_select_nb(max_size, _i, c.col))
    order_size = get_diraware_size_nb(order_size, order_direction)

    if (
        order_size_type == SizeType.TargetAmount
        or order_size_type == SizeType.TargetValue
        or order_size_type == SizeType.TargetPercent
        or order_size_type == SizeType.TargetPercent100
    ):
        order_price = flex_select_nb(price, _i, c.col)
        # NOTE(review): the valuation price is selected at `c.i`, while the other order fields
        # are selected at `_i` - confirm this is intended
        order_val_price = flex_select_nb(val_price, c.i, c.col)
        if np.isinf(order_val_price) and order_val_price > 0:
            # +inf valuation price: derive it from the order price
            # (+inf order price -> current close, -inf order price -> current open)
            if np.isinf(order_price) and order_price > 0:
                order_val_price = flex_select_nb(c.close, c.i, c.col)
            elif np.isinf(order_price) and order_price < 0:
                order_val_price = flex_select_nb(c.open, c.i, c.col)
            else:
                order_val_price = order_price
        elif np.isnan(order_val_price) or (np.isinf(order_val_price) and order_val_price < 0):
            # NaN or -inf valuation price: use the last valuation price of the column
            order_val_price = c.last_val_price[c.col]
        order_size, _ = resolve_size_nb(
            size=order_size,
            size_type=order_size_type,
            position=c.last_position[c.col],
            val_price=order_val_price,
            value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
        )
        if not np.isnan(min_order_size):
            min_order_size, _ = resolve_size_nb(
                size=min_order_size,
                size_type=order_size_type,
                position=c.last_position[c.col],
                val_price=order_val_price,
                value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
                as_requirement=True,
            )
        if not np.isnan(max_order_size):
            max_order_size, _ = resolve_size_nb(
                size=max_order_size,
                size_type=order_size_type,
                position=c.last_position[c.col],
                val_price=order_val_price,
                value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
                as_requirement=True,
            )
        # Write the resolved order back into the input arrays at the order's row
        order_size_type = SizeType.Amount
        size[_i, c.col] = abs(order_size)
        size_type[_i, c.col] = order_size_type
        min_size[_i, c.col] = min_order_size
        max_size[_i, c.col] = max_order_size
    else:
        size[_i, c.col] = abs(order_size)

    if order_size > 0:
        # Positive size: exit short if short-only and currently short, otherwise enter long
        if c.last_position[c.col] < 0 and order_direction == Direction.ShortOnly:
            return False, False, False, True
        return True, False, False, False
    if order_size < 0:
        # Negative size: exit long if long-only and currently long, otherwise enter short
        if c.last_position[c.col] > 0 and order_direction == Direction.LongOnly:
            return False, True, False, False
        return False, False, True, False
    return False, False, False, False


# %
# %
# % blocks["post_segment_func_nb"]
@register_jitted
def save_post_segment_func_nb(  # % line.replace("save_post_segment_func_nb", "post_segment_func_nb")
    c: SignalSegmentContext,
    save_state: bool = True,
    save_value: bool = True,
    save_returns: bool = True,
) -> None:
    """`post_segment_func_nb` that saves state, value, and returns.

    Writes the current row `c.i` of the arrays in `c.in_outputs`:

    * `save_state`: `position`, `debt`, and `locked_cash` per column, plus `cash` and `free_cash`
      (per group if `c.cash_sharing`, otherwise per column)
    * `save_value`: `value` (per group if `c.cash_sharing`, otherwise per column)
    * `save_returns`: `returns` (per group if `c.cash_sharing`, otherwise per column)"""
    if save_state:
        for col in range(c.from_col, c.to_col):
            c.in_outputs.position[c.i, col] = c.last_position[col]
            c.in_outputs.debt[c.i, col] = c.last_debt[col]
            c.in_outputs.locked_cash[c.i, col] = c.last_locked_cash[col]
        if c.cash_sharing:
            c.in_outputs.cash[c.i, c.group] = c.last_cash[c.group]
            c.in_outputs.free_cash[c.i, c.group] = c.last_free_cash[c.group]
        else:
            for col in range(c.from_col, c.to_col):
                c.in_outputs.cash[c.i, col] = c.last_cash[col]
                c.in_outputs.free_cash[c.i, col] = c.last_free_cash[col]
    if save_value:
        if c.cash_sharing:
            c.in_outputs.value[c.i, c.group] = c.last_value[c.group]
        else:
            for col in range(c.from_col, c.to_col):
                c.in_outputs.value[c.i, col] = c.last_value[col]
    if save_returns:
        if c.cash_sharing:
            c.in_outputs.returns[c.i, c.group] = c.last_return[c.group]
        else:
            for col in range(c.from_col, c.to_col):
                c.in_outputs.returns[c.i, col] = c.last_return[col]


# %
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for iterative portfolio simulation."""

from vectorbtpro import _typing as tp
from vectorbtpro.base.flex_indexing import flex_select_nb
from vectorbtpro.generic.nb.iter_ import (
    iter_above_nb as _iter_above_nb,
    iter_below_nb as _iter_below_nb,
    iter_crossed_above_nb as _iter_crossed_above_nb,
    iter_crossed_below_nb as _iter_crossed_below_nb,
)
from vectorbtpro.registries.jit_registry import register_jitted


@register_jitted
def select_nb(
    c: tp.NamedTuple,
    arr: tp.FlexArray2d,
    i: tp.Optional[int] = None,
    col: tp.Optional[int] = None,
) -> tp.Scalar:
    """Get the current element using flexible indexing.

    If `i` is None, uses `c.i`. If `col` is None, uses `c.col`.

    Returns the result of `flex_select_nb` on `arr` at the resolved row and column."""
    # Explicit if/else on the optional arguments is kept as is
    if i is None:
        _i = c.i
    else:
        _i = i
    if col is None:
        _col = c.col
    else:
        _col = col
    return flex_select_nb(arr, _i, _col)


@register_jitted
def select_from_col_nb(
    c: tp.NamedTuple,
    col: int,
    arr: tp.FlexArray2d,
    i: tp.Optional[int] = None,
) -> tp.Scalar:
    """Get the current element from a specific column using flexible indexing.

    The column `col` is always taken as given. If `i` is None, uses `c.i`."""
    if i is None:
        _i = c.i
    else:
        _i = i
    return flex_select_nb(arr, _i, col)


@register_jitted
def iter_above_nb(
    c: tp.NamedTuple,
    arr1: tp.FlexArray2d,
    arr2: tp.FlexArray2d,
    i: tp.Optional[int] = None,
    col: tp.Optional[int] = None,
) -> bool:
    """Call `vectorbtpro.generic.nb.iter_.iter_above_nb` on the context.

    If `i` is None, uses `c.i`. If `col` is None, uses `c.col`.
    The resolved row and column are passed along with `arr1` and `arr2`."""
    if i is None:
        _i = c.i
    else:
        _i = i
    if col is None:
        _col = c.col
    else:
        _col = col
    return _iter_above_nb(arr1, arr2, _i, _col)


@register_jitted
def iter_below_nb(
    c: tp.NamedTuple,
    arr1: tp.FlexArray2d,
    arr2: tp.FlexArray2d,
    i: tp.Optional[int] = None,
    col: tp.Optional[int] = None,
) -> bool:
    """Call `vectorbtpro.generic.nb.iter_.iter_below_nb` on the context.

    If `i` is None, uses `c.i`. If `col` is None, uses `c.col`.
    The resolved row and column are passed along with `arr1` and `arr2`."""
    if i is None:
        _i = c.i
    else:
        _i = i
    if col is None:
        _col = c.col
    else:
        _col = col
    return _iter_below_nb(arr1, arr2, _i, _col)


@register_jitted
def iter_crossed_above_nb(
    c: tp.NamedTuple,
    arr1: tp.FlexArray2d,
    arr2: tp.FlexArray2d,
    i: tp.Optional[int] = None,
    col: tp.Optional[int] = None,
) -> bool:
    """Call `vectorbtpro.generic.nb.iter_.iter_crossed_above_nb` on the context.

    If `i` is None, uses `c.i`. If `col` is None, uses `c.col`.
    The resolved row and column are passed along with `arr1` and `arr2`."""
    if i is None:
        _i = c.i
    else:
        _i = i
    if col is None:
        _col = c.col
    else:
        _col = col
    return _iter_crossed_above_nb(arr1, arr2, _i, _col)


@register_jitted
def iter_crossed_below_nb(
    c: tp.NamedTuple,
    arr1: tp.FlexArray2d,
    arr2: tp.FlexArray2d,
    i: tp.Optional[int] = None,
    col: tp.Optional[int] = None,
) -> bool:
    """Call `vectorbtpro.generic.nb.iter_.iter_crossed_below_nb` on the context.

    If `i` is None, uses `c.i`. If `col` is None, uses `c.col`.
    The resolved row and column are passed along with `arr1` and `arr2`."""
    if i is None:
        _i = c.i
    else:
        _i = i
    if col is None:
        _col = c.col
    else:
        _col = col
    return _iter_crossed_below_nb(arr1, arr2, _i, _col)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Numba-compiled functions for portfolio records."""

from numba import prange

from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.math_ import is_close_nb, is_close_or_less_nb, is_less_nb, add_nb
from vectorbtpro.utils.template import Rep

invalid_size_msg = "Encountered an order with size 0 or less"
invalid_price_msg = "Encountered an order with price less than 0"


@register_jitted(cache=True)
def records_within_sim_range_nb(
    target_shape: tp.Shape,
    records: tp.RecordArray,
    col_arr: tp.Array1d,
    idx_arr: tp.Array1d,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
    """Select the records whose index falls into the simulation range of their column.

    The per-column range is prepared with `generic_nb.prepare_sim_range_nb`; a record is kept
    when its entry in `idx_arr` is at or after the start and strictly before the end of the
    column given by `col_arr`. Records of columns with an empty range are dropped."""
    n_records = len(records)
    selected = np.empty(n_records, dtype=records.dtype)
    n_selected = 0

    range_starts, range_ends = generic_nb.prepare_sim_range_nb(
        sim_shape=target_shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    for r in range(n_records):
        record_col = col_arr[r]
        range_start = range_starts[record_col]
        range_end = range_ends[record_col]
        # An empty range (start >= end) cannot contain any record
        if range_start < range_end and range_start <= idx_arr[r] < range_end:
            selected[n_selected] = records[r]
            n_selected += 1
    return selected[:n_selected]


@register_jitted(cache=True)
def apply_weights_to_orders_nb(
    order_records: tp.RecordArray,
    col_arr: tp.Array1d,
    weights: tp.Array1d,
) -> tp.RecordArray:
    """Scale the size and fees of order records by the weight of their column.

    Works on a copy of `order_records`. Records whose column weight is NaN are kept unchanged;
    records whose scaled size equals zero are left out of the result."""
    scaled_records = order_records.copy()
    n_orders = len(scaled_records)
    kept = np.empty(n_orders, dtype=scaled_records.dtype)
    n_kept = 0

    for r in range(n_orders):
        record = scaled_records[r]
        weight = weights[col_arr[r]]
        if not np.isnan(weight):
            record["size"] = weight * record["size"]
            record["fees"] = weight * record["fees"]
            if record["size"] == 0:
                # Nothing left of this order after weighting
                continue
        kept[n_kept] = record
        n_kept += 1
    return kept[:n_kept]


@register_jitted(cache=True)
def weighted_price_reduce_meta_nb(
    idxs: tp.Array1d,
    col: int,
    size_arr: tp.Array1d,
    price_arr: tp.Array1d,
) -> float:
    """Average the prices at `idxs` weighted by their sizes.

    Pairs where the size or the price is NaN are skipped. Returns NaN when `idxs` is empty
    or when the accumulated size is zero."""
    weighted_total = 0.0
    size_total = 0.0
    for j in idxs:
        size = size_arr[j]
        price = price_arr[j]
        if np.isnan(size) or np.isnan(price):
            continue
        weighted_total += size * price
        size_total += size
    # Also covers empty `idxs`, where nothing has been accumulated
    if size_total == 0:
        return np.nan
    return weighted_total / size_total


@register_jitted(cache=True)
def fill_trade_record_nb(
    new_records: tp.RecordArray,
    r: int,
    col: int,
    size: float,
    entry_order_id: int,
    entry_idx: int,
    entry_price: float,
    entry_fees: float,
    exit_order_id: int,
    exit_idx: int,
    exit_price: float,
    exit_fees: float,
    direction: int,
    status: int,
    parent_id: int,
) -> None:
    """Write one trade into `new_records` at position `r`.

    PnL and return are obtained from `get_trade_stats_nb`; the record id is set to `r`."""
    trade_pnl, trade_return = get_trade_stats_nb(
        size,
        entry_price,
        entry_fees,
        exit_price,
        exit_fees,
        direction,
    )

    # Identity
    new_records["id"][r] = r
    new_records["col"][r] = col
    new_records["parent_id"][r] = parent_id
    new_records["direction"][r] = direction
    new_records["status"][r] = status
    new_records["size"][r] = size

    # Entry side
    new_records["entry_order_id"][r] = entry_order_id
    new_records["entry_idx"][r] = entry_idx
    new_records["entry_price"][r] = entry_price
    new_records["entry_fees"][r] = entry_fees

    # Exit side
    new_records["exit_order_id"][r] = exit_order_id
    new_records["exit_idx"][r] = exit_idx
    new_records["exit_price"][r] = exit_price
    new_records["exit_fees"][r] = exit_fees

    # Performance
    new_records["pnl"][r] = trade_pnl
    new_records["return"][r] = trade_return
new_records["parent_id"][r] = parent_id @register_jitted(cache=True) def fill_entry_trades_in_position_nb( order_records: tp.RecordArray, col_map: tp.GroupMap, col: int, sim_start: int, sim_end: int, first_c: int, last_c: int, init_price: float, first_entry_size: float, first_entry_fees: float, exit_idx: int, exit_size_sum: float, exit_gross_sum: float, exit_fees_sum: float, direction: int, status: int, parent_id: int, new_records: tp.RecordArray, r: int, ) -> int: """Fill entry trades located within a single position. Returns the next trade id.""" col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens # Iterate over orders located within a single position for c in range(first_c, last_c + 1): if c == -1: entry_order_id = -1 entry_idx = -1 entry_size = first_entry_size entry_price = init_price entry_fees = first_entry_fees else: order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < sim_start or order_record["idx"] >= sim_end: continue entry_order_id = order_record["id"] entry_idx = order_record["idx"] entry_price = order_record["price"] order_side = order_record["side"] # Ignore exit orders if (direction == TradeDirection.Long and order_side == OrderSide.Sell) or ( direction == TradeDirection.Short and order_side == OrderSide.Buy ): continue if c == first_c: entry_size = first_entry_size entry_fees = first_entry_fees else: entry_size = order_record["size"] entry_fees = order_record["fees"] # Take a size-weighted average of exit price exit_price = exit_gross_sum / exit_size_sum # Take a fraction of exit fees size_fraction = entry_size / exit_size_sum exit_fees = size_fraction * exit_fees_sum # Fill the record if status == TradeStatus.Closed: exit_order_record = order_records[col_idxs[col_start_idxs[col] + last_c]] if exit_order_record["idx"] < sim_start or exit_order_record["idx"] >= sim_end: exit_order_id = -1 else: exit_order_id = exit_order_record["id"] else: exit_order_id = -1 fill_trade_record_nb( 
new_records, r, col, entry_size, entry_order_id, entry_idx, entry_price, entry_fees, exit_order_id, exit_idx, exit_price, exit_fees, direction, status, parent_id, ) r += 1 return r @register_chunkable( size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), close=base_ch.FlexArraySlicer(axis=1), col_map=base_ch.GroupMapSlicer(), init_position=base_ch.FlexArraySlicer(), init_price=base_ch.FlexArraySlicer(), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func=records_ch.merge_records, merge_kwargs=dict(chunk_meta=Rep("chunk_meta")), ) @register_jitted(cache=True, tags={"can_parallel"}) def get_entry_trades_nb( order_records: tp.RecordArray, close: tp.FlexArray2dLike, col_map: tp.GroupMap, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.RecordArray: """Fill entry trade records by aggregating order records. Entry trade records are buy orders in a long position and sell orders in a short position. Usage: ```pycon >>> from vectorbtpro import * >>> close = order_price = np.array([ ... [1, 6], ... [2, 5], ... [3, 4], ... [4, 3], ... [5, 2], ... [6, 1] ... ]) >>> size = np.array([ ... [1, -1], ... [0.1, -0.1], ... [-1, 1], ... [-0.1, 0.1], ... [1, -1], ... [-2, 2] ... ]) >>> target_shape = close.shape >>> group_lens = np.full(target_shape[1], 1) >>> init_cash = np.full(target_shape[1], 100) >>> sim_out = vbt.pf_nb.from_orders_nb( ... target_shape, ... group_lens, ... init_cash=init_cash, ... size=size, ... price=close, ... fees=np.asarray([[0.01]]), ... slippage=np.asarray([[0.01]]) ... 
) >>> col_map = vbt.rec_nb.col_map_nb(sim_out.order_records['col'], target_shape[1]) >>> entry_trade_records = vbt.pf_nb.get_entry_trades_nb(sim_out.order_records, close, col_map) >>> pd.DataFrame.from_records(entry_trade_records) id col size entry_order_id entry_idx entry_price entry_fees \\ 0 0 0 1.0 0 0 1.01 0.01010 1 1 0 0.1 1 1 2.02 0.00202 2 2 0 1.0 4 4 5.05 0.05050 3 3 0 1.0 5 5 5.94 0.05940 4 0 1 1.0 0 0 5.94 0.05940 5 1 1 0.1 1 1 4.95 0.00495 6 2 1 1.0 4 4 1.98 0.01980 7 3 1 1.0 5 5 1.01 0.01010 exit_order_id exit_idx exit_price exit_fees pnl return \\ 0 3 3 3.060000 0.030600 2.009300 1.989406 1 3 3 3.060000 0.003060 0.098920 0.489703 2 5 5 5.940000 0.059400 0.780100 0.154475 3 -1 5 6.000000 0.000000 -0.119400 -0.020101 4 3 3 3.948182 0.039482 1.892936 0.318676 5 3 3 3.948182 0.003948 0.091284 0.184411 6 5 5 1.010000 0.010100 0.940100 0.474798 7 -1 5 1.000000 0.000000 -0.020100 -0.019901 direction status parent_id 0 0 1 0 1 0 1 0 2 0 1 1 3 1 0 2 4 1 1 0 5 1 1 0 6 1 1 1 7 0 0 2 ``` """ close_ = to_2d_array_nb(np.asarray(close)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens max_records = np.max(col_lens) + 1 new_records = np.empty((max_records, len(col_lens)), dtype=trade_dt) counts = np.full(len(col_lens), 0, dtype=int_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(close_.shape[0], col_lens.shape[0]), sim_start=sim_start, sim_end=sim_end, ) for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _init_position = float(flex_select_1d_pc_nb(init_position_, col)) _init_price = float(flex_select_1d_pc_nb(init_price_, col)) if _init_position != 0: # Prepare initial position first_c = -1 in_position = True parent_id = 0 if _init_position >= 0: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = 
abs(_init_position) entry_gross_sum = abs(_init_position) * _init_price entry_fees_sum = 0.0 exit_size_sum = 0.0 exit_gross_sum = 0.0 exit_fees_sum = 0.0 first_entry_size = _init_position first_entry_fees = 0.0 else: in_position = False parent_id = -1 col_len = col_lens[col] if col_len == 0 and not in_position: continue last_id = -1 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] order_idx = order_record["idx"] order_size = order_record["size"] order_price = order_record["price"] order_fees = order_record["fees"] order_side = order_record["side"] if order_size <= 0.0: raise ValueError(invalid_size_msg) if order_price < 0.0: raise ValueError(invalid_price_msg) if not in_position: # New position opened first_c = c in_position = True parent_id += 1 if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = 0.0 entry_gross_sum = 0.0 entry_fees_sum = 0.0 exit_size_sum = 0.0 exit_gross_sum = 0.0 exit_fees_sum = 0.0 first_entry_size = order_size first_entry_fees = order_fees if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or ( direction == TradeDirection.Short and order_side == OrderSide.Sell ): # Position increased entry_size_sum += order_size entry_gross_sum += order_size * order_price entry_fees_sum += order_fees elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or ( direction == TradeDirection.Short and order_side == OrderSide.Buy ): if is_close_nb(exit_size_sum + order_size, entry_size_sum): # Position closed last_c = c in_position = False exit_size_sum = entry_size_sum exit_gross_sum += order_size * order_price exit_fees_sum += order_fees # Fill trade records counts[col] = fill_entry_trades_in_position_nb( 
order_records, col_map, col, _sim_start, _sim_end, first_c, last_c, _init_price, first_entry_size, first_entry_fees, order_idx, exit_size_sum, exit_gross_sum, exit_fees_sum, direction, TradeStatus.Closed, parent_id, new_records[:, col], counts[col], ) elif is_less_nb(exit_size_sum + order_size, entry_size_sum): # Position decreased exit_size_sum += order_size exit_gross_sum += order_size * order_price exit_fees_sum += order_fees else: # Position closed last_c = c remaining_size = add_nb(entry_size_sum, -exit_size_sum) exit_size_sum = entry_size_sum exit_gross_sum += remaining_size * order_price exit_fees_sum += remaining_size / order_size * order_fees # Fill trade records counts[col] = fill_entry_trades_in_position_nb( order_records, col_map, col, _sim_start, _sim_end, first_c, last_c, _init_price, first_entry_size, first_entry_fees, order_idx, exit_size_sum, exit_gross_sum, exit_fees_sum, direction, TradeStatus.Closed, parent_id, new_records[:, col], counts[col], ) # New position opened first_c = c parent_id += 1 if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = add_nb(order_size, -remaining_size) entry_gross_sum = entry_size_sum * order_price entry_fees_sum = entry_size_sum / order_size * order_fees first_entry_size = entry_size_sum first_entry_fees = entry_fees_sum exit_size_sum = 0.0 exit_gross_sum = 0.0 exit_fees_sum = 0.0 if in_position and is_less_nb(exit_size_sum, entry_size_sum): # Position hasn't been closed last_c = col_len - 1 remaining_size = add_nb(entry_size_sum, -exit_size_sum) exit_size_sum = entry_size_sum last_close = flex_select_nb(close_, _sim_end - 1, col) if np.isnan(last_close): for ri in range(_sim_end - 1, -1, -1): _close = flex_select_nb(close_, ri, col) if not np.isnan(_close): last_close = _close break exit_gross_sum += remaining_size * last_close exit_idx = _sim_end - 1 # Fill trade records counts[col] = fill_entry_trades_in_position_nb( order_records, col_map, col, 
_sim_start, _sim_end, first_c, last_c, _init_price, first_entry_size, first_entry_fees, exit_idx, exit_size_sum, exit_gross_sum, exit_fees_sum, direction, TradeStatus.Open, parent_id, new_records[:, col], counts[col], ) return generic_nb.repartition_nb(new_records, counts) @register_chunkable( size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), close=base_ch.FlexArraySlicer(axis=1), col_map=base_ch.GroupMapSlicer(), init_position=base_ch.FlexArraySlicer(), init_price=base_ch.FlexArraySlicer(), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func=records_ch.merge_records, merge_kwargs=dict(chunk_meta=Rep("chunk_meta")), ) @register_jitted(cache=True, tags={"can_parallel"}) def get_exit_trades_nb( order_records: tp.RecordArray, close: tp.FlexArray2dLike, col_map: tp.GroupMap, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.RecordArray: """Fill exit trade records by aggregating order records. Exit trade records are sell orders in a long position and buy orders in a short position. 
Usage: * Building upon the example in `get_exit_trades_nb`: ```pycon >>> exit_trade_records = vbt.pf_nb.get_exit_trades_nb(sim_out.order_records, close, col_map) >>> pd.DataFrame.from_records(exit_trade_records) id col size entry_order_id entry_idx entry_price entry_fees \\ 0 0 0 1.0 0 0 1.101818 0.011018 1 1 0 0.1 0 0 1.101818 0.001102 2 2 0 1.0 4 4 5.050000 0.050500 3 3 0 1.0 5 5 5.940000 0.059400 4 0 1 1.0 0 0 5.850000 0.058500 5 1 1 0.1 0 0 5.850000 0.005850 6 2 1 1.0 4 4 1.980000 0.019800 7 3 1 1.0 5 5 1.010000 0.010100 exit_order_id exit_idx exit_price exit_fees pnl return \\ 0 2 2 2.97 0.02970 1.827464 1.658589 1 3 3 3.96 0.00396 0.280756 2.548119 2 5 5 5.94 0.05940 0.780100 0.154475 3 -1 5 6.00 0.00000 -0.119400 -0.020101 4 2 2 4.04 0.04040 1.711100 0.292496 5 3 3 3.03 0.00303 0.273120 0.466872 6 5 5 1.01 0.01010 0.940100 0.474798 7 -1 5 1.00 0.00000 -0.020100 -0.019901 direction status parent_id 0 0 1 0 1 0 1 0 2 0 1 1 3 1 0 2 4 1 1 0 5 1 1 0 6 1 1 1 7 0 0 2 ``` """ close_ = to_2d_array_nb(np.asarray(close)) init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens max_records = np.max(col_lens) + 1 new_records = np.empty((max_records, len(col_lens)), dtype=trade_dt) counts = np.full(len(col_lens), 0, dtype=int_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(close_.shape[0], col_lens.shape[0]), sim_start=sim_start, sim_end=sim_end, ) for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _init_position = float(flex_select_1d_pc_nb(init_position_, col)) _init_price = float(flex_select_1d_pc_nb(init_price_, col)) if _init_position != 0: # Prepare initial position in_position = True parent_id = 0 entry_order_id = -1 entry_idx = -1 if _init_position >= 0: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum 
= abs(_init_position) entry_gross_sum = abs(_init_position) * _init_price entry_fees_sum = 0.0 else: in_position = False parent_id = -1 col_len = col_lens[col] if col_len == 0 and not in_position: continue last_id = -1 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] order_idx = order_record["idx"] order_id = order_record["id"] order_size = order_record["size"] order_price = order_record["price"] order_fees = order_record["fees"] order_side = order_record["side"] if order_size <= 0.0: raise ValueError(invalid_size_msg) if order_price < 0.0: raise ValueError(invalid_price_msg) if not in_position: # Trade opened in_position = True entry_order_id = order_id entry_idx = order_idx if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short parent_id += 1 entry_size_sum = 0.0 entry_gross_sum = 0.0 entry_fees_sum = 0.0 if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or ( direction == TradeDirection.Short and order_side == OrderSide.Sell ): # Position increased entry_size_sum += order_size entry_gross_sum += order_size * order_price entry_fees_sum += order_fees elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or ( direction == TradeDirection.Short and order_side == OrderSide.Buy ): if is_close_or_less_nb(order_size, entry_size_sum): # Trade closed if is_close_nb(order_size, entry_size_sum): exit_size = entry_size_sum else: exit_size = order_size exit_price = order_price exit_fees = order_fees exit_order_id = order_id exit_idx = order_idx # Take a size-weighted average of entry price entry_price = entry_gross_sum / entry_size_sum # Take a fraction of entry fees size_fraction = exit_size / entry_size_sum entry_fees = size_fraction * 
entry_fees_sum fill_trade_record_nb( new_records[:, col], counts[col], col, exit_size, entry_order_id, entry_idx, entry_price, entry_fees, exit_order_id, exit_idx, exit_price, exit_fees, direction, TradeStatus.Closed, parent_id, ) counts[col] += 1 if is_close_nb(order_size, entry_size_sum): # Position closed entry_order_id = -1 entry_idx = -1 direction = -1 in_position = False else: # Position decreased, previous orders have now less impact size_fraction = (entry_size_sum - order_size) / entry_size_sum entry_size_sum *= size_fraction entry_gross_sum *= size_fraction entry_fees_sum *= size_fraction else: # Trade reversed # Close current trade cl_exit_size = entry_size_sum cl_exit_price = order_price cl_exit_fees = cl_exit_size / order_size * order_fees cl_exit_order_id = order_id cl_exit_idx = order_idx # Take a size-weighted average of entry price entry_price = entry_gross_sum / entry_size_sum # Take a fraction of entry fees size_fraction = cl_exit_size / entry_size_sum entry_fees = size_fraction * entry_fees_sum fill_trade_record_nb( new_records[:, col], counts[col], col, cl_exit_size, entry_order_id, entry_idx, entry_price, entry_fees, cl_exit_order_id, cl_exit_idx, cl_exit_price, cl_exit_fees, direction, TradeStatus.Closed, parent_id, ) counts[col] += 1 # Open a new trade entry_size_sum = order_size - cl_exit_size entry_gross_sum = entry_size_sum * order_price entry_fees_sum = order_fees - cl_exit_fees entry_order_id = order_id entry_idx = order_idx if direction == TradeDirection.Long: direction = TradeDirection.Short else: direction = TradeDirection.Long parent_id += 1 if in_position and is_less_nb(-entry_size_sum, 0): # Trade hasn't been closed exit_size = entry_size_sum last_close = flex_select_nb(close_, _sim_end - 1, col) if np.isnan(last_close): for ri in range(_sim_end - 1, -1, -1): _close = flex_select_nb(close_, ri, col) if not np.isnan(_close): last_close = _close break exit_price = last_close exit_fees = 0.0 exit_order_id = -1 exit_idx = _sim_end - 1 
# Take a size-weighted average of entry price entry_price = entry_gross_sum / entry_size_sum # Take a fraction of entry fees size_fraction = exit_size / entry_size_sum entry_fees = size_fraction * entry_fees_sum fill_trade_record_nb( new_records[:, col], counts[col], col, exit_size, entry_order_id, entry_idx, entry_price, entry_fees, exit_order_id, exit_idx, exit_price, exit_fees, direction, TradeStatus.Open, parent_id, ) counts[col] += 1 return generic_nb.repartition_nb(new_records, counts) @register_jitted(cache=True) def fill_position_record_nb(new_records: tp.RecordArray, r: int, trade_records: tp.RecordArray) -> None: """Fill a position record by aggregating trade records.""" # Aggregate trades col = trade_records["col"][0] size = np.sum(trade_records["size"]) entry_order_id = trade_records["entry_order_id"][0] entry_idx = trade_records["entry_idx"][0] entry_price = np.sum(trade_records["size"] * trade_records["entry_price"]) / size entry_fees = np.sum(trade_records["entry_fees"]) exit_order_id = trade_records["exit_order_id"][-1] exit_idx = trade_records["exit_idx"][-1] exit_price = np.sum(trade_records["size"] * trade_records["exit_price"]) / size exit_fees = np.sum(trade_records["exit_fees"]) direction = trade_records["direction"][-1] status = trade_records["status"][-1] pnl, ret = get_trade_stats_nb(size, entry_price, entry_fees, exit_price, exit_fees, direction) # Save position new_records["id"][r] = r new_records["col"][r] = col new_records["size"][r] = size new_records["entry_order_id"][r] = entry_order_id new_records["entry_idx"][r] = entry_idx new_records["entry_price"][r] = entry_price new_records["entry_fees"][r] = entry_fees new_records["exit_order_id"][r] = exit_order_id new_records["exit_idx"][r] = exit_idx new_records["exit_price"][r] = exit_price new_records["exit_fees"][r] = exit_fees new_records["pnl"][r] = pnl new_records["return"][r] = ret new_records["direction"][r] = direction new_records["status"][r] = status new_records["parent_id"][r] 
= r @register_jitted(cache=True) def copy_trade_record_nb(new_records: tp.RecordArray, r: int, trade_record: tp.Record) -> None: """Copy a trade record.""" new_records["id"][r] = r new_records["col"][r] = trade_record["col"] new_records["size"][r] = trade_record["size"] new_records["entry_order_id"][r] = trade_record["entry_order_id"] new_records["entry_idx"][r] = trade_record["entry_idx"] new_records["entry_price"][r] = trade_record["entry_price"] new_records["entry_fees"][r] = trade_record["entry_fees"] new_records["exit_order_id"][r] = trade_record["exit_order_id"] new_records["exit_idx"][r] = trade_record["exit_idx"] new_records["exit_price"][r] = trade_record["exit_price"] new_records["exit_fees"][r] = trade_record["exit_fees"] new_records["pnl"][r] = trade_record["pnl"] new_records["return"][r] = trade_record["return"] new_records["direction"][r] = trade_record["direction"] new_records["status"][r] = trade_record["status"] new_records["parent_id"][r] = r @register_chunkable( size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( trade_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), col_map=base_ch.GroupMapSlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def get_positions_nb(trade_records: tp.RecordArray, col_map: tp.GroupMap) -> tp.RecordArray: """Fill position records by aggregating trade records. Trades can be entry trades, exit trades, and even positions themselves - all will produce the same results. 
Usage: * Building upon the example in `get_exit_trades_nb`: ```pycon >>> col_map = vbt.rec_nb.col_map_nb(exit_trade_records['col'], target_shape[1]) >>> position_records = vbt.pf_nb.get_positions_nb(exit_trade_records, col_map) >>> pd.DataFrame.from_records(position_records) id col size entry_order_id entry_idx entry_price entry_fees \\ 0 0 0 1.1 0 0 1.101818 0.01212 1 1 0 1.0 4 4 5.050000 0.05050 2 2 0 1.0 5 5 5.940000 0.05940 3 0 1 1.1 0 0 5.850000 0.06435 4 1 1 1.0 4 4 1.980000 0.01980 5 2 1 1.0 5 5 1.010000 0.01010 exit_order_id exit_idx exit_price exit_fees pnl return \\ 0 3 3 3.060000 0.03366 2.10822 1.739455 1 5 5 5.940000 0.05940 0.78010 0.154475 2 -1 5 6.000000 0.00000 -0.11940 -0.020101 3 3 3 3.948182 0.04343 1.98422 0.308348 4 5 5 1.010000 0.01010 0.94010 0.474798 5 -1 5 1.000000 0.00000 -0.02010 -0.019901 direction status parent_id 0 0 1 0 1 0 1 1 2 1 0 2 3 1 1 0 4 1 1 1 5 0 0 2 ``` """ col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens new_records = np.empty((np.max(col_lens), len(col_lens)), dtype=trade_dt) counts = np.full(len(col_lens), 0, dtype=int_) for col in prange(col_lens.shape[0]): col_len = col_lens[col] if col_len == 0: continue last_id = -1 last_position_id = -1 from_trade_r = -1 for c in range(col_len): trade_r = col_idxs[col_start_idxs[col] + c] record = trade_records[trade_r] if record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = record["id"] parent_id = record["parent_id"] if parent_id != last_position_id: if last_position_id != -1: if trade_r - from_trade_r > 1: fill_position_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r:trade_r]) else: # Speed up copy_trade_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r]) counts[col] += 1 from_trade_r = trade_r last_position_id = parent_id if trade_r - from_trade_r > 0: fill_position_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r : trade_r + 1]) else: # 
Speed up copy_trade_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r]) counts[col] += 1 return generic_nb.repartition_nb(new_records, counts) @register_chunkable( size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), close=ch.ArraySlicer(axis=1), col_map=base_ch.GroupMapSlicer(), init_position=base_ch.FlexArraySlicer(), init_price=base_ch.FlexArraySlicer(), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def get_long_view_orders_nb( order_records: tp.RecordArray, close: tp.Array2d, col_map: tp.GroupMap, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.RecordArray: """Get view of orders in long positions only.""" init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) order_records = order_records.copy() out = np.empty(order_records.shape, dtype=order_records.dtype) r = 0 col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(close.shape[0], col_lens.shape[0]), sim_start=sim_start, sim_end=sim_end, ) for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _init_position = float(flex_select_1d_pc_nb(init_position_, col)) _init_price = float(flex_select_1d_pc_nb(init_price_, col)) if _init_position != 0: # Prepare initial position in_position = True if _init_position >= 0: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = abs(_init_position) entry_gross_sum = abs(_init_position) * _init_price exit_size_sum = 0.0 exit_gross_sum = 0.0 else: in_position = False col_len = 
col_lens[col] if col_len == 0 and not in_position: continue last_id = -1 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] order_size = order_record["size"] order_price = order_record["price"] order_side = order_record["side"] order_fees = order_record["fees"] if order_size <= 0.0: raise ValueError(invalid_size_msg) if order_price < 0.0: raise ValueError(invalid_price_msg) if not in_position: # New position opened in_position = True if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = 0.0 entry_gross_sum = 0.0 exit_size_sum = 0.0 exit_gross_sum = 0.0 if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or ( direction == TradeDirection.Short and order_side == OrderSide.Sell ): # Position increased entry_size_sum += order_size entry_gross_sum += order_size * order_price if direction == TradeDirection.Long: out[r] = order_record r += 1 elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or ( direction == TradeDirection.Short and order_side == OrderSide.Buy ): if is_close_nb(exit_size_sum + order_size, entry_size_sum): # Position closed in_position = False exit_size_sum = entry_size_sum exit_gross_sum += order_size * order_price if direction == TradeDirection.Long: out[r] = order_record r += 1 elif is_less_nb(exit_size_sum + order_size, entry_size_sum): # Position decreased exit_size_sum += order_size exit_gross_sum += order_size * order_price if direction == TradeDirection.Long: out[r] = order_record r += 1 else: # Position closed remaining_size = add_nb(entry_size_sum, -exit_size_sum) # New position opened entry_size_sum = add_nb(order_size, -remaining_size) entry_gross_sum = entry_size_sum * order_price exit_size_sum = 
0.0 exit_gross_sum = 0.0 if direction == TradeDirection.Long: out[r] = order_record out["size"][r] = remaining_size out["fees"][r] = remaining_size / order_size * order_fees r += 1 else: out[r] = order_record out["size"][r] = entry_size_sum out["fees"][r] = entry_size_sum / order_size * order_fees r += 1 if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short return out[:r] @register_chunkable( size=base_ch.GroupLensSizer(arg_query="col_map"), arg_take_spec=dict( order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper), close=ch.ArraySlicer(axis=1), col_map=base_ch.GroupMapSlicer(), init_position=base_ch.FlexArraySlicer(), init_price=base_ch.FlexArraySlicer(), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def get_short_view_orders_nb( order_records: tp.RecordArray, close: tp.Array2d, col_map: tp.GroupMap, init_position: tp.FlexArray1dLike = 0.0, init_price: tp.FlexArray1dLike = np.nan, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.RecordArray: """Get view of orders in short positions only.""" init_position_ = to_1d_array_nb(np.asarray(init_position)) init_price_ = to_1d_array_nb(np.asarray(init_price)) order_records = order_records.copy() out = np.empty(order_records.shape, dtype=order_records.dtype) r = 0 col_idxs, col_lens = col_map col_start_idxs = np.cumsum(col_lens) - col_lens sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=(close.shape[0], col_lens.shape[0]), sim_start=sim_start, sim_end=sim_end, ) for col in prange(col_lens.shape[0]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _init_position = float(flex_select_1d_pc_nb(init_position_, col)) _init_price = float(flex_select_1d_pc_nb(init_price_, col)) if _init_position != 0: # Prepare initial position in_position = True 
if _init_position >= 0: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = abs(_init_position) entry_gross_sum = abs(_init_position) * _init_price exit_size_sum = 0.0 exit_gross_sum = 0.0 else: in_position = False col_len = col_lens[col] if col_len == 0 and not in_position: continue last_id = -1 for c in range(col_len): order_record = order_records[col_idxs[col_start_idxs[col] + c]] if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end: continue if order_record["id"] < last_id: raise ValueError("Ids must come in ascending order per column") last_id = order_record["id"] order_size = order_record["size"] order_price = order_record["price"] order_side = order_record["side"] order_fees = order_record["fees"] if order_size <= 0.0: raise ValueError(invalid_size_msg) if order_price < 0.0: raise ValueError(invalid_price_msg) if not in_position: # New position opened in_position = True if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short entry_size_sum = 0.0 entry_gross_sum = 0.0 exit_size_sum = 0.0 exit_gross_sum = 0.0 if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or ( direction == TradeDirection.Short and order_side == OrderSide.Sell ): # Position increased entry_size_sum += order_size entry_gross_sum += order_size * order_price if direction == TradeDirection.Short: out[r] = order_record r += 1 elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or ( direction == TradeDirection.Short and order_side == OrderSide.Buy ): if is_close_nb(exit_size_sum + order_size, entry_size_sum): # Position closed in_position = False exit_size_sum = entry_size_sum exit_gross_sum += order_size * order_price if direction == TradeDirection.Short: out[r] = order_record r += 1 elif is_less_nb(exit_size_sum + order_size, entry_size_sum): # Position decreased exit_size_sum += order_size exit_gross_sum += order_size * order_price if direction == 
TradeDirection.Short: out[r] = order_record r += 1 else: # Position closed remaining_size = add_nb(entry_size_sum, -exit_size_sum) # New position opened entry_size_sum = add_nb(order_size, -remaining_size) entry_gross_sum = entry_size_sum * order_price exit_size_sum = 0.0 exit_gross_sum = 0.0 if direction == TradeDirection.Short: out[r] = order_record out["size"][r] = remaining_size out["fees"][r] = remaining_size / order_size * order_fees r += 1 else: out[r] = order_record out["size"][r] = entry_size_sum out["fees"][r] = entry_size_sum / order_size * order_fees r += 1 if order_side == OrderSide.Buy: direction = TradeDirection.Long else: direction = TradeDirection.Short return out[:r]


@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        close=ch.ArraySlicer(axis=1),
        col_map=base_ch.GroupMapSlicer(),
        feature=None,
        init_position=base_ch.FlexArraySlicer(),
        init_price=base_ch.FlexArraySlicer(),
        fill_closed_position=None,
        fill_exit_price=None,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_position_feature_nb(
    order_records: tp.RecordArray,
    close: tp.Array2d,
    col_map: tp.GroupMap,
    feature: int = PositionFeature.EntryPrice,
    init_position: tp.FlexArray1dLike = 0.0,
    init_price: tp.FlexArray1dLike = np.nan,
    fill_closed_position: bool = False,
    fill_exit_price: bool = True,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Get the position's feature at each time step.

    For the list of supported features see `vectorbtpro.portfolio.enums.PositionFeature`.
    Only `PositionFeature.EntryPrice` and `PositionFeature.ExitPrice` are written by this
    function; the output has the shape of `close` and stays NaN wherever nothing is written.

    The entry (exit) price is the size-weighted average of the prices of all orders that
    increased (decreased) the current position. Rows between two orders receive the state
    established by the earlier order; orders outside of `[sim_start, sim_end)` of their
    column are ignored.

    If `fill_exit_price` is True and a part of the position is not closed yet, will fill
    the exit price as if the part was closed using the current close.

    If `fill_closed_position` is True, will forward-fill missing values with the prices
    of the previously closed position.

    Raises `ValueError` if order ids do not come in ascending order within a column,
    if an order size is not positive, or if an order price is negative."""
    init_position_ = to_1d_array_nb(np.asarray(init_position))
    init_price_ = to_1d_array_nb(np.asarray(init_price))

    out = np.full(close.shape, np.nan, dtype=float_)
    col_idxs, col_lens = col_map
    # Offset of each column's first entry in `col_idxs`
    col_start_idxs = np.cumsum(col_lens) - col_lens

    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=(close.shape[0], col_lens.shape[0]),
        sim_start=sim_start,
        sim_end=sim_end,
    )

    for col in prange(col_lens.shape[0]):
        _sim_start = sim_start_[col]
        _sim_end = sim_end_[col]
        if _sim_start >= _sim_end:
            # Empty simulation range: the column stays NaN
            continue
        _init_position = float(flex_select_1d_pc_nb(init_position_, col))
        _init_price = float(flex_select_1d_pc_nb(init_price_, col))
        if _init_position != 0:
            # Prepare initial position
            in_position = True
            was_in_position = True
            if _init_position >= 0:
                direction = TradeDirection.Long
            else:
                direction = TradeDirection.Short
            entry_size_sum = abs(_init_position)
            entry_gross_sum = abs(_init_position) * _init_price
            exit_size_sum = 0.0
            exit_gross_sum = 0.0
            # NOTE(review): the initial position is filled from row 0, not from `_sim_start`,
            # while the upper bound of the last fill is `_sim_end` - confirm this is intended
            last_order_idx = 0
        else:
            in_position = False
            was_in_position = False

        col_len = col_lens[col]
        if col_len == 0 and not in_position:
            # Neither orders nor an initial position: nothing to fill
            continue
        last_id = -1

        for c in range(col_len):
            order_record = order_records[col_idxs[col_start_idxs[col] + c]]
            if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
                continue

            if order_record["id"] < last_id:
                raise ValueError("Ids must come in ascending order per column")
            last_id = order_record["id"]

            order_idx = order_record["idx"]
            order_size = order_record["size"]
            order_price = order_record["price"]
            order_side = order_record["side"]

            if order_size <= 0.0:
                raise ValueError(invalid_size_msg)
            if order_price < 0.0:
                raise ValueError(invalid_price_msg)

            # Fill the rows since the previous order (end-exclusive, so the row of the
            # current order is left to the next fill) using the state before this order
            if in_position:
                if feature == PositionFeature.EntryPrice:
                    if entry_size_sum != 0:
                        entry_price = entry_gross_sum / entry_size_sum
                        out[last_order_idx:order_idx, col] = entry_price
                elif feature == PositionFeature.ExitPrice:
                    if fill_exit_price:
                        # Value the still open part at each row's close and average
                        # over the whole entry size
                        remaining_size = add_nb(entry_size_sum, -exit_size_sum)
                        for i in range(last_order_idx, order_idx):
                            open_exit_size_sum = entry_size_sum
                            open_exit_gross_sum = exit_gross_sum + remaining_size * close[i, col]
                            exit_price = open_exit_gross_sum / open_exit_size_sum
                            out[i, col] = exit_price
                    else:
                        if exit_size_sum != 0:
                            exit_price = exit_gross_sum / exit_size_sum
                            out[last_order_idx:order_idx, col] = exit_price
            else:
                if was_in_position and fill_closed_position:
                    # Carry the prices of the last closed position over the flat period
                    if feature == PositionFeature.EntryPrice:
                        if entry_size_sum != 0:
                            entry_price = entry_gross_sum / entry_size_sum
                            out[last_order_idx:order_idx, col] = entry_price
                    elif feature == PositionFeature.ExitPrice:
                        if exit_size_sum != 0:
                            exit_price = exit_gross_sum / exit_size_sum
                            out[last_order_idx:order_idx, col] = exit_price

                # New position opened
                in_position = True
                was_in_position = True
                if order_side == OrderSide.Buy:
                    direction = TradeDirection.Long
                else:
                    direction = TradeDirection.Short
                entry_size_sum = 0.0
                entry_gross_sum = 0.0
                exit_size_sum = 0.0
                exit_gross_sum = 0.0

            if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
                direction == TradeDirection.Short and order_side == OrderSide.Sell
            ):
                # Position increased
                entry_size_sum += order_size
                entry_gross_sum += order_size * order_price

            elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
                direction == TradeDirection.Short and order_side == OrderSide.Buy
            ):
                if is_close_nb(exit_size_sum + order_size, entry_size_sum):
                    # Position closed
                    in_position = False
                    exit_size_sum = entry_size_sum
                    exit_gross_sum += order_size * order_price
                elif is_less_nb(exit_size_sum + order_size, entry_size_sum):
                    # Position decreased
                    exit_size_sum += order_size
                    exit_gross_sum += order_size * order_price
                else:
                    # Position closed
                    remaining_size = add_nb(entry_size_sum, -exit_size_sum)

                    # New position opened (reversal): only the part of the order that
                    # exceeds the remaining size enters the new position
                    if order_side == OrderSide.Buy:
                        direction = TradeDirection.Long
                    else:
                        direction = TradeDirection.Short
                    entry_size_sum = add_nb(order_size, -remaining_size)
                    entry_gross_sum = entry_size_sum * order_price
                    exit_size_sum = 0.0
                    exit_gross_sum = 0.0

            last_order_idx = order_idx

        # Fill the rows from the last processed order up to the end of the simulation range
        if in_position:
            if feature == PositionFeature.EntryPrice:
                if entry_size_sum != 0:
                    entry_price = entry_gross_sum / entry_size_sum
                    out[last_order_idx:_sim_end, col] = entry_price
            elif feature == PositionFeature.ExitPrice:
                if fill_exit_price:
                    remaining_size = add_nb(entry_size_sum, -exit_size_sum)
                    for i in range(last_order_idx, _sim_end):
                        open_exit_size_sum = entry_size_sum
                        open_exit_gross_sum = exit_gross_sum + remaining_size * close[i, col]
                        exit_price = open_exit_gross_sum / open_exit_size_sum
                        out[i, col] = exit_price
                else:
                    if exit_size_sum != 0:
                        exit_price = exit_gross_sum / exit_size_sum
                        out[last_order_idx:_sim_end, col] = exit_price
        elif was_in_position and fill_closed_position:
            if feature == PositionFeature.EntryPrice:
                if entry_size_sum != 0:
                    entry_price = entry_gross_sum / entry_size_sum
                    out[last_order_idx:_sim_end, col] = entry_price
            elif feature == PositionFeature.ExitPrice:
                if exit_size_sum != 0:
                    exit_price = exit_gross_sum / exit_size_sum
                    out[last_order_idx:_sim_end, col] = exit_price
    return out


@register_jitted(cache=True)
def price_status_nb(
    records: tp.RecordArray,
    high: tp.Optional[tp.FlexArray2d],
    low: tp.Optional[tp.FlexArray2d],
) -> tp.Array1d:
    """Return the status of the order's price related to high and low.
See `vectorbtpro.portfolio.enums.OrderPriceStatus`.""" out = np.full(len(records), 0, dtype=int_) for i in range(len(records)): order = records[i] if high is not None: _high = float(flex_select_nb(high, order["idx"], order["col"])) else: _high = np.nan if low is not None: _low = flex_select_nb(low, order["idx"], order["col"]) else: _low = np.nan if not np.isnan(_high) and order["price"] > _high: out[i] = OrderPriceStatus.AboveHigh elif not np.isnan(_low) and order["price"] < _low: out[i] = OrderPriceStatus.BelowLow elif np.isnan(_high) or np.isnan(_low): out[i] = OrderPriceStatus.Unknown else: out[i] = OrderPriceStatus.OK return out @register_jitted(cache=True) def trade_winning_streak_nb(records: tp.RecordArray) -> tp.Array1d: """Return the current winning streak of each trade.""" out = np.full(len(records), 0, dtype=int_) curr_rank = 0 for i in range(len(records)): if records[i]["pnl"] > 0: curr_rank += 1 else: curr_rank = 0 out[i] = curr_rank return out @register_jitted(cache=True) def trade_losing_streak_nb(records: tp.RecordArray) -> tp.Array1d: """Return the current losing streak of each trade.""" out = np.full(len(records), 0, dtype=int_) curr_rank = 0 for i in range(len(records)): if records[i]["pnl"] < 0: curr_rank += 1 else: curr_rank = 0 out[i] = curr_rank return out @register_jitted(cache=True) def win_rate_reduce_nb(pnl_arr: tp.Array1d) -> float: """Win rate of a PnL array.""" if pnl_arr.shape[0] == 0: return np.nan win_count = 0 count = 0 for i in range(len(pnl_arr)): if not np.isnan(pnl_arr[i]): count += 1 if pnl_arr[i] > 0: win_count += 1 if count == 0: return np.nan return win_count / count @register_jitted(cache=True) def profit_factor_reduce_nb(pnl_arr: tp.Array1d) -> float: """Profit factor of a PnL array.""" if pnl_arr.shape[0] == 0: return np.nan win_sum = 0 loss_sum = 0 count = 0 for i in range(len(pnl_arr)): if not np.isnan(pnl_arr[i]): count += 1 if pnl_arr[i] > 0: win_sum += pnl_arr[i] elif pnl_arr[i] < 0: loss_sum += abs(pnl_arr[i]) if 
loss_sum == 0: return np.inf return win_sum / loss_sum @register_jitted(cache=True) def expectancy_reduce_nb(pnl_arr: tp.Array1d) -> float: """Expectancy of a PnL array.""" if pnl_arr.shape[0] == 0: return np.nan win_count = 0 win_sum = 0 loss_count = 0 loss_sum = 0 count = 0 for i in range(len(pnl_arr)): if not np.isnan(pnl_arr[i]): count += 1 if pnl_arr[i] > 0: win_count += 1 win_sum += pnl_arr[i] elif pnl_arr[i] < 0: loss_count += 1 loss_sum += abs(pnl_arr[i]) if count == 0: return np.nan win_rate = win_count / count if win_count == 0: win_mean = 0.0 else: win_mean = win_sum / win_count loss_rate = loss_count / count if loss_count == 0: loss_mean = 0.0 else: loss_mean = loss_sum / loss_count return win_rate * win_mean - loss_rate * loss_mean @register_jitted(cache=True) def sqn_reduce_nb(pnl_arr: tp.Array1d, ddof: int = 0) -> float: """SQN of a PnL array.""" count = generic_nb.nancnt_1d_nb(pnl_arr) mean = np.nanmean(pnl_arr) std = generic_nb.nanstd_1d_nb(pnl_arr, ddof=ddof) if std == 0: return np.nan return np.sqrt(count) * mean / std @register_jitted(cache=True) def trade_best_worst_price_nb( trade: tp.Record, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, idx_relative: bool = True, cont_idx: int = -1, one_iteration: bool = False, vmin: float = np.nan, vmax: float = np.nan, imin: int = -1, imax: int = -1, ) -> tp.Tuple[float, float, int, int]: """Best price, worst price, and their indices during a trade.""" from_i = trade["entry_idx"] to_i = trade["exit_idx"] trade_open = trade["status"] == TradeStatus.Open trade_long = trade["direction"] == TradeDirection.Long if cont_idx == -1 or cont_idx == from_i: cont_idx = from_i vmin = np.nan vmax = np.nan imin = -1 imax = -1 else: if trade_long: vmin, vmax, imin, imax = vmax, vmin, imax, imin if idx_relative: imin = from_i + imin imax = 
from_i + imax for i in range(cont_idx, to_i + 1): if i == from_i: if np.isnan(vmin) or trade["entry_price"] < vmin: vmin = trade["entry_price"] imin = i if np.isnan(vmax) or trade["entry_price"] > vmax: vmax = trade["entry_price"] imax = i if i > from_i or entry_price_open: if open is not None: _open = flex_select_nb(open, i, trade["col"]) if np.isnan(vmin) or _open < vmin: vmin = _open imin = i if np.isnan(vmax) or _open > vmax: vmax = _open imax = i if (i > from_i or entry_price_open) and (i < to_i or exit_price_close or trade_open): if low is not None: _low = flex_select_nb(low, i, trade["col"]) if np.isnan(vmin) or _low < vmin: vmin = _low imin = i if high is not None: _high = flex_select_nb(high, i, trade["col"]) if np.isnan(vmax) or _high > vmax: vmax = _high imax = i if i < to_i or exit_price_close or trade_open: _close = flex_select_nb(close, i, trade["col"]) if np.isnan(vmin) or _close < vmin: vmin = _close imin = i if np.isnan(vmax) or _close > vmax: vmax = _close imax = i if max_duration is not None: if from_i + max_duration == i: break if i == to_i: if np.isnan(vmin) or trade["exit_price"] < vmin: vmin = trade["exit_price"] imin = i if np.isnan(vmax) or trade["exit_price"] > vmax: vmax = trade["exit_price"] imax = i if one_iteration: break if idx_relative: imin = imin - from_i imax = imax - from_i if trade_long: return vmax, vmin, imax, imin return vmin, vmax, imin, imax @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), open=None, high=None, low=None, close=None, entry_price_open=None, exit_price_close=None, max_duration=None, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def best_price_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, ) -> 
tp.Array1d: """Get best price by applying `trade_best_worst_price_nb` on each trade.""" out = np.empty(len(records), dtype=float_) for r in prange(len(records)): trade = records[r] out[r] = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, )[0] return out @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), open=None, high=None, low=None, close=None, entry_price_open=None, exit_price_close=None, max_duration=None, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def worst_price_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, ) -> tp.Array1d: """Get worst price by applying `trade_best_worst_price_nb` on each trade.""" out = np.empty(len(records), dtype=float_) for r in prange(len(records)): trade = records[r] out[r] = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, )[1] return out @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), open=None, high=None, low=None, close=None, entry_price_open=None, exit_price_close=None, max_duration=None, relative=None, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def best_price_idx_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, relative: bool = True, ) -> 
tp.Array1d: """Get index of best price by applying `trade_best_worst_price_nb` on each trade.""" out = np.empty(len(records), dtype=float_) for r in prange(len(records)): trade = records[r] out[r] = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, idx_relative=relative, )[2] return out @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), open=None, high=None, low=None, close=None, entry_price_open=None, exit_price_close=None, max_duration=None, relative=None, ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def worst_price_idx_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, relative: bool = True, ) -> tp.Array1d: """Get worst price by applying `trade_best_worst_price_nb` on each trade.""" out = np.empty(len(records), dtype=float_) for r in prange(len(records)): trade = records[r] out[r] = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, idx_relative=relative, )[3] return out @register_jitted(cache=True, tags={"can_parallel"}) def expanding_best_price_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, ) -> tp.Array2d: """Get expanding best price of each trade.""" if max_duration is None: _max_duration = 0 for r in range(len(records)): trade = records[r] trade_duration = trade["exit_idx"] - 
trade["entry_idx"] if trade_duration > _max_duration: _max_duration = trade_duration else: _max_duration = max_duration out = np.full((_max_duration + 1, len(records)), np.nan, dtype=float_) for r in prange(len(records)): trade = records[r] from_i = trade["entry_idx"] to_i = trade["exit_idx"] vmin = np.nan vmax = np.nan imin = -1 imax = -1 for i in range(from_i, to_i + 1): vmin, vmax, imin, imax = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, cont_idx=i, one_iteration=True, vmin=vmin, vmax=vmax, imin=imin, imax=imax, ) out[i - from_i, r] = vmin if max_duration is not None: if from_i + max_duration == i: break return out @register_jitted(cache=True, tags={"can_parallel"}) def expanding_worst_price_nb( records: tp.RecordArray, open: tp.Optional[tp.FlexArray2d], high: tp.Optional[tp.FlexArray2d], low: tp.Optional[tp.FlexArray2d], close: tp.FlexArray2d, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, ) -> tp.Array2d: """Get expanding worst price of each trade.""" if max_duration is None: _max_duration = 0 for r in range(len(records)): trade = records[r] trade_duration = trade["exit_idx"] - trade["entry_idx"] if trade_duration > _max_duration: _max_duration = trade_duration else: _max_duration = max_duration out = np.full((_max_duration + 1, len(records)), np.nan, dtype=float_) for r in prange(len(records)): trade = records[r] from_i = trade["entry_idx"] to_i = trade["exit_idx"] vmin = np.nan vmax = np.nan imin = -1 imax = -1 for i in range(from_i, to_i + 1): vmin, vmax, imin, imax = trade_best_worst_price_nb( trade=trade, open=open, high=high, low=low, close=close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, cont_idx=i, one_iteration=True, vmin=vmin, vmax=vmax, imin=imin, imax=imax, ) out[i - from_i, r] = vmax if max_duration is 
not None: if from_i + max_duration == i: break return out @register_jitted(cache=True) def trade_mfe_nb( size: float, direction: int, entry_price: float, best_price: float, use_returns: bool = False, ) -> float: """Compute Maximum Favorable Excursion (MFE).""" if direction == TradeDirection.Long: if use_returns: return (best_price - entry_price) / entry_price return (best_price - entry_price) * size if use_returns: return (entry_price - best_price) / best_price return (entry_price - best_price) * size @register_chunkable( size=ch.ArraySizer(arg_query="size", axis=0), arg_take_spec=dict( size=ch.ArraySlicer(axis=0), direction=ch.ArraySlicer(axis=0), entry_price=ch.ArraySlicer(axis=0), best_price=ch.ArraySlicer(axis=0), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def mfe_nb( size: tp.Array1d, direction: tp.Array1d, entry_price: tp.Array1d, best_price: tp.Array1d, use_returns: bool = False, ) -> tp.Array1d: """Apply `trade_mfe_nb` on each trade.""" out = np.empty(size.shape[0], dtype=float_) for r in prange(size.shape[0]): out[r] = trade_mfe_nb( size=size[r], direction=direction[r], entry_price=entry_price[r], best_price=best_price[r], use_returns=use_returns, ) return out @register_jitted(cache=True) def trade_mae_nb( size: float, direction: int, entry_price: float, worst_price: float, use_returns: bool = False, ) -> float: """Compute Maximum Adverse Excursion (MAE).""" if direction == TradeDirection.Long: if use_returns: return (worst_price - entry_price) / entry_price return (worst_price - entry_price) * size if use_returns: return (entry_price - worst_price) / worst_price return (entry_price - worst_price) * size @register_chunkable( size=ch.ArraySizer(arg_query="size", axis=0), arg_take_spec=dict( size=ch.ArraySlicer(axis=0), direction=ch.ArraySlicer(axis=0), entry_price=ch.ArraySlicer(axis=0), worst_price=ch.ArraySlicer(axis=0), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def mae_nb( size: 
tp.Array1d, direction: tp.Array1d, entry_price: tp.Array1d, worst_price: tp.Array1d, use_returns: bool = False, ) -> tp.Array1d: """Apply `trade_mae_nb` on each trade.""" out = np.empty(size.shape[0], dtype=float_) for r in prange(size.shape[0]): out[r] = trade_mae_nb( size=size[r], direction=direction[r], entry_price=entry_price[r], worst_price=worst_price[r], use_returns=use_returns, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), expanding_best_price=ch.ArraySlicer(axis=1), use_returns=None, ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def expanding_mfe_nb( records: tp.RecordArray, expanding_best_price: tp.Array2d, use_returns: bool = False, ) -> tp.Array2d: """Get expanding MFE of each trade.""" out = np.empty_like(expanding_best_price, dtype=float_) for r in prange(expanding_best_price.shape[1]): for i in range(expanding_best_price.shape[0]): out[i, r] = trade_mfe_nb( size=records["size"][r], direction=records["direction"][r], entry_price=records["entry_price"][r], best_price=expanding_best_price[i, r], use_returns=use_returns, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="records", axis=0), arg_take_spec=dict( records=ch.ArraySlicer(axis=0), expanding_worst_price=ch.ArraySlicer(axis=1), use_returns=None, ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def expanding_mae_nb( records: tp.RecordArray, expanding_worst_price: tp.Array2d, use_returns: bool = False, ) -> tp.Array2d: """Get expanding MAE of each trade.""" out = np.empty_like(expanding_worst_price, dtype=float_) for r in prange(expanding_worst_price.shape[1]): for i in range(expanding_worst_price.shape[0]): out[i, r] = trade_mae_nb( size=records["size"][r], direction=records["direction"][r], entry_price=records["entry_price"][r], worst_price=expanding_worst_price[i, r], use_returns=use_returns, ) return out 
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        open=None,
        high=None,
        low=None,
        close=None,
        volatility=None,
        entry_price_open=None,
        exit_price_close=None,
        max_duration=None,
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def edge_ratio_nb(
    records: tp.RecordArray,
    col_map: tp.GroupMap,
    open: tp.Optional[tp.FlexArray2d],
    high: tp.Optional[tp.FlexArray2d],
    low: tp.Optional[tp.FlexArray2d],
    close: tp.FlexArray2d,
    volatility: tp.FlexArray2d,
    entry_price_open: bool = False,
    exit_price_close: bool = False,
    max_duration: tp.Optional[int] = None,
) -> tp.Array1d:
    """Get edge ratio of each column.

    For each trade, the absolute MFE and MAE (not as returns) are divided by the volatility
    selected at the trade's entry row and column. The edge ratio of a column is the mean
    of the normalized MFEs divided by the mean of the normalized MAEs, with NaN values
    and trades with zero volatility left out. Columns without trades and columns where
    the mean MAE is zero yield NaN."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full(len(col_lens), np.nan, dtype=float_)

    for col in prange(col_lens.shape[0]):
        n_trades = col_lens[col]
        if n_trades == 0:
            continue
        first = col_start_idxs[col]

        mfe_total = 0.0
        mfe_n = 0
        mae_total = 0.0
        mae_n = 0
        for j in range(first, first + n_trades):
            trade = records[col_idxs[j]]
            best_price, worst_price, _, _ = trade_best_worst_price_nb(
                trade=trade,
                open=open,
                high=high,
                low=low,
                close=close,
                entry_price_open=entry_price_open,
                exit_price_close=exit_price_close,
                max_duration=max_duration,
            )
            mfe = abs(
                trade_mfe_nb(
                    size=trade["size"],
                    direction=trade["direction"],
                    entry_price=trade["entry_price"],
                    best_price=best_price,
                    use_returns=False,
                )
            )
            mae = abs(
                trade_mae_nb(
                    size=trade["size"],
                    direction=trade["direction"],
                    entry_price=trade["entry_price"],
                    worst_price=worst_price,
                    use_returns=False,
                )
            )
            _volatility = flex_select_nb(volatility, trade["entry_idx"], trade["col"])
            if _volatility != 0:
                # Zero volatility cannot normalize anything, so such trades are left out
                norm_mfe = mfe / _volatility
                norm_mae = mae / _volatility
                if not np.isnan(norm_mfe):
                    mfe_total += norm_mfe
                    mfe_n += 1
                if not np.isnan(norm_mae):
                    mae_total += norm_mae
                    mae_n += 1

        mean_mfe = mfe_total / mfe_n if mfe_n != 0 else np.nan
        mean_mae = mae_total / mae_n if mae_n != 0 else np.nan
        if mean_mae != 0:
            # Otherwise keep the NaN the output was initialized with
            out[col] = mean_mfe / mean_mae
    return out


@register_jitted(cache=True, tags={"can_parallel"})
def running_edge_ratio_nb(
    records: tp.RecordArray,
    col_map: tp.GroupMap,
    open: tp.Optional[tp.FlexArray2d],
    high: tp.Optional[tp.FlexArray2d],
    low: tp.Optional[tp.FlexArray2d],
    close: tp.FlexArray2d,
    volatility: tp.FlexArray2d,
    entry_price_open: bool = False,
    exit_price_close: bool = False,
    max_duration: tp.Optional[int] = None,
    incl_shorter: bool = False,
) -> tp.Array2d:
    """Get running edge ratio of each column.

    Row `k` of the output holds the edge ratio of each column computed with the maximum
    duration of `k + 1`. The number of rows is `max_duration`; if `max_duration` is None,
    the longest duration (`exit_idx - entry_idx`) among `records` is used. Unless
    `incl_shorter` is True, trades shorter than `k + 1` are left out of row `k`."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens

    if max_duration is not None:
        _max_duration = max_duration
    else:
        _max_duration = 0
        for r in range(len(records)):
            duration = records["exit_idx"][r] - records["entry_idx"][r]
            if duration > _max_duration:
                _max_duration = duration
    out = np.full((_max_duration, len(col_lens)), np.nan, dtype=float_)

    for col in prange(col_lens.shape[0]):
        n_trades = col_lens[col]
        if n_trades == 0:
            continue
        first = col_start_idxs[col]

        for k in range(_max_duration):
            horizon = k + 1
            mfe_total = 0.0
            mfe_n = 0
            mae_total = 0.0
            mae_n = 0
            for j in range(first, first + n_trades):
                trade = records[col_idxs[j]]
                if not incl_shorter and trade["exit_idx"] - trade["entry_idx"] < horizon:
                    continue
                best_price, worst_price, _, _ = trade_best_worst_price_nb(
                    trade=trade,
                    open=open,
                    high=high,
                    low=low,
                    close=close,
                    entry_price_open=entry_price_open,
                    exit_price_close=exit_price_close,
                    max_duration=horizon,
                )
                mfe = abs(
                    trade_mfe_nb(
                        size=trade["size"],
                        direction=trade["direction"],
                        entry_price=trade["entry_price"],
                        best_price=best_price,
                        use_returns=False,
                    )
                )
                mae = abs(
                    trade_mae_nb(
                        size=trade["size"],
                        direction=trade["direction"],
                        entry_price=trade["entry_price"],
                        worst_price=worst_price,
                        use_returns=False,
                    )
                )
                _volatility = flex_select_nb(volatility, trade["entry_idx"], trade["col"])
                if _volatility != 0:
                    # Zero volatility cannot normalize anything, so such trades are left out
                    norm_mfe = mfe / _volatility
                    norm_mae = mae / _volatility
                    if not np.isnan(norm_mfe):
                        mfe_total += norm_mfe
                        mfe_n += 1
                    if not np.isnan(norm_mae):
                        mae_total += norm_mae
                        mae_n += 1

            mean_mfe = mfe_total / mfe_n if mfe_n != 0 else np.nan
            mean_mae = mae_total / mae_n if mae_n != 0 else np.nan
            if mean_mae != 0:
                # Otherwise keep the NaN the output was initialized with
                out[k, col] = mean_mfe / mean_mae
    return out


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Modules with classes and utilities for portfolio optimization."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vectorbtpro.portfolio.pfopt.base import *
    from vectorbtpro.portfolio.pfopt.nb import *
    from vectorbtpro.portfolio.pfopt.records import *

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== """Base functions and classes for portfolio optimization.""" import inspect import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.base.indexes import combine_indexes, stack_indexes, select_levels from vectorbtpro.base.indexing import point_idxr_defaults, range_idxr_defaults from vectorbtpro.base.merging import row_stack_arrays from vectorbtpro.base.reshaping import to_pd_array, to_1d_array, to_2d_array, to_dict, broadcast_array_to from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.data.base import Data from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.generic.enums import RangeStatus from vectorbtpro.generic.splitting.base import Splitter, Takeable from vectorbtpro.portfolio.enums import Direction from vectorbtpro.portfolio.pfopt import nb from vectorbtpro.portfolio.pfopt.records import AllocRanges, AllocPoints from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.returns.accessors import ReturnsAccessor from vectorbtpro.utils import checks, datetime_ as dt from vectorbtpro.utils.annotations import has_annotatables from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig from vectorbtpro.utils.decorators import hybrid_method from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.execution import Task, execute from vectorbtpro.utils.params import Param, combine_params, Parameterizer from vectorbtpro.utils.parsing import ( get_func_arg_names, annotate_args, flatten_ann_args, unflatten_ann_args, ann_args_to_args, ) from vectorbtpro.utils.pickling import pdict from vectorbtpro.utils.random_ import set_seed_nb from vectorbtpro.utils.template import substitute_templates, Rep, RepFunc, CustomTemplate from vectorbtpro.utils.warnings_ import warn, warn_stdout, WarningsFiltered if tp.TYPE_CHECKING: from 
vectorbtpro.portfolio.base import Portfolio as PortfolioT else: PortfolioT = "Portfolio" try: if not tp.TYPE_CHECKING: raise ImportError from pypfopt.base_optimizer import BaseOptimizer as BaseOptimizerT except ImportError: BaseOptimizerT = "BaseOptimizer" try: if not tp.TYPE_CHECKING: raise ImportError from riskfolio import Portfolio as RPortfolio, HCPortfolio as RHCPortfolio RPortfolioT = tp.TypeVar("RPortfolioT", bound=tp.Union[RPortfolio, RHCPortfolio]) except ImportError: RPortfolioT = "RPortfolio" try: if not tp.TYPE_CHECKING: raise ImportError from universal.algo import Algo from universal.result import AlgoResult AlgoT = tp.TypeVar("AlgoT", bound=Algo) AlgoResultT = tp.TypeVar("AlgoResultT", bound=AlgoResult) except ImportError: AlgoT = "Algo" AlgoResultT = "AlgoResult" __all__ = [ "pfopt_func_dict", "pypfopt_optimize", "riskfolio_optimize", "PortfolioOptimizer", "PFO", ] __pdoc__ = {} # ############# PyPortfolioOpt ############# # class pfopt_func_dict(pdict): """Dict that contains optimization functions as keys. 
Keys can be functions themselves, their names, or `_def` for the default value.""" pass def select_pfopt_func_kwargs( pypfopt_func: tp.Callable, kwargs: tp.Union[None, tp.Kwargs, pfopt_func_dict] = None, ) -> tp.Kwargs: """Select keyword arguments belonging to `pypfopt_func`.""" if kwargs is None: return {} if isinstance(kwargs, pfopt_func_dict): if pypfopt_func in kwargs: _kwargs = kwargs[pypfopt_func] elif pypfopt_func.__name__ in kwargs: _kwargs = kwargs[pypfopt_func.__name__] elif "_def" in kwargs: _kwargs = kwargs["_def"] else: _kwargs = {} else: _kwargs = {} for k, v in kwargs.items(): if isinstance(v, pfopt_func_dict): if pypfopt_func in v: _kwargs[k] = v[pypfopt_func] elif pypfopt_func.__name__ in v: _kwargs[k] = v[pypfopt_func.__name__] elif "_def" in v: _kwargs[k] = v["_def"] else: _kwargs[k] = v return _kwargs def resolve_pypfopt_func_kwargs( pypfopt_func: tp.Callable, cache: tp.KwargsLike = None, var_kwarg_names: tp.Optional[tp.Iterable[str]] = None, used_arg_names: tp.Optional[tp.Set[str]] = None, **kwargs, ) -> tp.Kwargs: """Resolve keyword arguments passed to any optimization function with the layout of PyPortfolioOpt. Parses the signature of `pypfopt_func`, and for each accepted argument, looks for an argument with the same name in `kwargs`. If not found, tries to resolve that argument using other arguments or by calling other optimization functions. Argument `frequency` gets resolved with (global) `freq` and `year_freq` using `vectorbtpro.returns.accessors.ReturnsAccessor.get_ann_factor`. Any argument in `kwargs` can be wrapped using `pfopt_func_dict` to define the argument per function rather than globally. !!! note When providing custom functions, make sure that the arguments they accept are visible in the signature (that is, no variable arguments) and have the same naming as in PyPortfolioOpt. 
Functions `market_implied_prior_returns` and `BlackLittermanModel.bl_weights` take `risk_aversion`, which is different from arguments with the same name in other functions. To set it, pass `delta`.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("pypfopt") signature = inspect.signature(pypfopt_func) kwargs = select_pfopt_func_kwargs(pypfopt_func, kwargs) if cache is None: cache = {} arg_names = get_func_arg_names(pypfopt_func) if len(arg_names) == 0: return {} if used_arg_names is None: used_arg_names = set() pass_kwargs = dict() def _process_arg(arg_name, arg_value): orig_arg_name = arg_name if pypfopt_func.__name__ in ("market_implied_prior_returns", "bl_weights"): if arg_name == "risk_aversion": # In some methods, risk_aversion is expected as array and means delta arg_name = "delta" def _get_kwarg(*args): used_arg_names.add(args[0]) return kwargs.get(*args) def _get_prices(): prices = None if "prices" in cache: prices = cache["prices"] elif "prices" in kwargs: if not _get_kwarg("returns_data", False): prices = _get_kwarg("prices") return prices def _get_returns(): returns = None if "returns" in cache: returns = cache["returns"] elif "returns" in kwargs: returns = _get_kwarg("returns") elif "prices" in kwargs and _get_kwarg("returns_data", False): returns = _get_kwarg("prices") return returns def _prices_from_returns(): from pypfopt.expected_returns import prices_from_returns cache["prices"] = prices_from_returns(_get_returns(), _get_kwarg("log_returns", False)) return cache["prices"] def _returns_from_prices(): from pypfopt.expected_returns import returns_from_prices cache["returns"] = returns_from_prices(_get_prices(), _get_kwarg("log_returns", False)) return cache["returns"] if arg_name == "expected_returns": if arg_name in kwargs: used_arg_names.add(arg_name) if "expected_returns" not in cache: cache["expected_returns"] = resolve_pypfopt_expected_returns( cache=cache, used_arg_names=used_arg_names, **kwargs, ) 
pass_kwargs[orig_arg_name] = cache["expected_returns"] elif arg_name == "cov_matrix": if arg_name in kwargs: used_arg_names.add(arg_name) if "cov_matrix" not in cache: cache["cov_matrix"] = resolve_pypfopt_cov_matrix( cache=cache, used_arg_names=used_arg_names, **kwargs, ) pass_kwargs[orig_arg_name] = cache["cov_matrix"] elif arg_name == "optimizer": if arg_name in kwargs: used_arg_names.add(arg_name) if "optimizer" not in cache: cache["optimizer"] = resolve_pypfopt_optimizer( cache=cache, used_arg_names=used_arg_names, **kwargs, ) pass_kwargs[orig_arg_name] = cache["optimizer"] if orig_arg_name not in pass_kwargs: if arg_name in kwargs: if arg_name == "market_prices": if pypfopt_func.__name__ != "market_implied_risk_aversion" and checks.is_series( _get_kwarg(arg_name) ): pass_kwargs[orig_arg_name] = _get_kwarg(arg_name).to_frame().copy(deep=False) else: pass_kwargs[orig_arg_name] = _get_kwarg(arg_name).copy(deep=False) else: pass_kwargs[orig_arg_name] = _get_kwarg(arg_name) else: if arg_name == "frequency": ann_factor = ReturnsAccessor.get_ann_factor( year_freq=_get_kwarg("year_freq", None), freq=_get_kwarg("freq", None), ) if ann_factor is not None: pass_kwargs[orig_arg_name] = ann_factor elif arg_name == "prices": if "returns_data" in arg_names: if "returns_data" in kwargs: if _get_kwarg("returns_data", False): if _get_returns() is not None: pass_kwargs[orig_arg_name] = _get_returns() elif _get_prices() is not None: pass_kwargs[orig_arg_name] = _returns_from_prices() else: if _get_prices() is not None: pass_kwargs[orig_arg_name] = _get_prices() elif _get_returns() is not None: pass_kwargs[orig_arg_name] = _prices_from_returns() else: if _get_prices() is not None: pass_kwargs[orig_arg_name] = _get_prices() pass_kwargs["returns_data"] = False elif _get_returns() is not None: pass_kwargs[orig_arg_name] = _get_returns() pass_kwargs["returns_data"] = True else: if _get_prices() is not None: pass_kwargs[orig_arg_name] = _get_prices() elif _get_returns() is not None: 
def resolve_pypfopt_func_call(pypfopt_func: tp.Callable, **kwargs) -> tp.Any:
    """Resolve arguments using `resolve_pypfopt_func_kwargs` and call the function with those arguments."""
    return pypfopt_func(**resolve_pypfopt_func_kwargs(pypfopt_func, **kwargs))


def resolve_pypfopt_expected_returns(
    expected_returns: tp.Union[tp.Callable, tp.AnyArray, str] = "mean_historical_return",
    **kwargs,
) -> tp.AnyArray:
    """Resolve the expected returns.

    `expected_returns` can be an array, an attribute of `pypfopt.expected_returns`, a function,
    or one of the following options (matched case-insensitively):

    * 'mean_historical_return': `pypfopt.expected_returns.mean_historical_return`
    * 'ema_historical_return': `pypfopt.expected_returns.ema_historical_return`
    * 'capm_return': `pypfopt.expected_returns.capm_return`
    * 'bl_returns': `pypfopt.black_litterman.BlackLittermanModel.bl_returns`

    Any other string is looked up (as passed, case-sensitively) as an attribute of
    `pypfopt.expected_returns`; if no such attribute exists, `NotImplementedError` is raised.

    Any function is resolved using `resolve_pypfopt_func_call`.
    Any other object is returned as it is."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("pypfopt")

    if isinstance(expected_returns, str):
        # Normalize once instead of lowering the string in every comparison
        model_name = expected_returns.lower()
        if model_name == "mean_historical_return":
            from pypfopt.expected_returns import mean_historical_return

            return resolve_pypfopt_func_call(mean_historical_return, **kwargs)
        if model_name == "ema_historical_return":
            from pypfopt.expected_returns import ema_historical_return

            return resolve_pypfopt_func_call(ema_historical_return, **kwargs)
        if model_name == "capm_return":
            from pypfopt.expected_returns import capm_return

            return resolve_pypfopt_func_call(capm_return, **kwargs)
        if model_name == "bl_returns":
            from pypfopt.black_litterman import BlackLittermanModel

            return resolve_pypfopt_func_call(
                BlackLittermanModel,
                var_kwarg_names=["market_caps", "risk_free_rate"],
                **kwargs,
            ).bl_returns()
        import pypfopt.expected_returns

        # Fall back to any attribute of the module, using the name exactly as provided
        if hasattr(pypfopt.expected_returns, expected_returns):
            return resolve_pypfopt_func_call(getattr(pypfopt.expected_returns, expected_returns), **kwargs)
        raise NotImplementedError("Return model '{}' is not supported".format(expected_returns))
    if callable(expected_returns):
        return resolve_pypfopt_func_call(expected_returns, **kwargs)
    return expected_returns


def resolve_pypfopt_cov_matrix(
    cov_matrix: tp.Union[tp.Callable, tp.AnyArray, str] = "ledoit_wolf",
    **kwargs,
) -> tp.AnyArray:
    """Resolve the covariance matrix.

    `cov_matrix` can be an array, an attribute of `pypfopt.risk_models`, a function,
    or one of the following options (matched case-insensitively):

    * 'sample_cov': `pypfopt.risk_models.sample_cov`
    * 'semicovariance' or 'semivariance': `pypfopt.risk_models.semicovariance`
    * 'exp_cov': `pypfopt.risk_models.exp_cov`
    * 'ledoit_wolf' or 'ledoit_wolf_constant_variance': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
        called without arguments (that is, with its default shrinkage target)
    * 'ledoit_wolf_single_factor': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
        with 'single_factor' as shrinkage target
    * 'ledoit_wolf_constant_correlation': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
        with 'constant_correlation' as shrinkage target
    * 'oracle_approximating': `pypfopt.risk_models.CovarianceShrinkage.oracle_approximating`

    For 'sample_cov', 'semicovariance', and 'exp_cov', the argument `fix_method` is additionally
    resolved as a variable keyword argument.

    Any other string is looked up (as passed, case-sensitively) as an attribute of
    `pypfopt.risk_models`; if no such attribute exists, `NotImplementedError` is raised.

    Any function is resolved using `resolve_pypfopt_func_call`.
    Any other object is returned as it is."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("pypfopt")

    if isinstance(cov_matrix, str):
        # Normalize once instead of lowering the string in every comparison
        model_name = cov_matrix.lower()
        if model_name == "sample_cov":
            from pypfopt.risk_models import sample_cov

            return resolve_pypfopt_func_call(sample_cov, var_kwarg_names=["fix_method"], **kwargs)
        if model_name in ("semicovariance", "semivariance"):
            from pypfopt.risk_models import semicovariance

            return resolve_pypfopt_func_call(semicovariance, var_kwarg_names=["fix_method"], **kwargs)
        if model_name == "exp_cov":
            from pypfopt.risk_models import exp_cov

            return resolve_pypfopt_func_call(exp_cov, var_kwarg_names=["fix_method"], **kwargs)
        if model_name in ("ledoit_wolf", "ledoit_wolf_constant_variance"):
            from pypfopt.risk_models import CovarianceShrinkage

            return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf()
        if model_name == "ledoit_wolf_single_factor":
            from pypfopt.risk_models import CovarianceShrinkage

            return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf(
                shrinkage_target="single_factor"
            )
        if model_name == "ledoit_wolf_constant_correlation":
            from pypfopt.risk_models import CovarianceShrinkage

            return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf(
                shrinkage_target="constant_correlation"
            )
        if model_name == "oracle_approximating":
            from pypfopt.risk_models import CovarianceShrinkage

            return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).oracle_approximating()
        import pypfopt.risk_models

        # Fall back to any attribute of the module, using the name exactly as provided
        if hasattr(pypfopt.risk_models, cov_matrix):
            return resolve_pypfopt_func_call(getattr(pypfopt.risk_models, cov_matrix), **kwargs)
        raise NotImplementedError("Risk model '{}' is not supported".format(cov_matrix))
    if callable(cov_matrix):
        return resolve_pypfopt_func_call(cov_matrix, **kwargs)
    return cov_matrix


def resolve_pypfopt_optimizer(
    optimizer: tp.Union[tp.Callable, BaseOptimizerT, str] = "efficient_frontier",
    **kwargs,
) -> BaseOptimizerT:
    """Resolve the optimizer.

    `optimizer` can be an instance of `pypfopt.base_optimizer.BaseOptimizer`, an attribute of `pypfopt`,
    a subclass of `pypfopt.base_optimizer.BaseOptimizer`, or one of the following options
    (matched case-insensitively):

    * 'efficient_frontier': `pypfopt.efficient_frontier.EfficientFrontier`
    * 'efficient_cdar': `pypfopt.efficient_frontier.EfficientCDaR`
    * 'efficient_cvar': `pypfopt.efficient_frontier.EfficientCVaR`
    * 'efficient_semivariance': `pypfopt.efficient_frontier.EfficientSemivariance`
    * 'black_litterman' or 'bl': `pypfopt.black_litterman.BlackLittermanModel`
    * 'hierarchical_portfolio', 'hrpopt', or 'hrp': `pypfopt.hierarchical_portfolio.HRPOpt`
    * 'cla': `pypfopt.cla.CLA`

    Any other string is looked up (as passed, case-sensitively) as an attribute of `pypfopt`.
    An instance of `pypfopt.base_optimizer.BaseOptimizer` is returned as it is.
    Anything else raises `NotImplementedError`.

    Any class is instantiated using `resolve_pypfopt_func_call`."""
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("pypfopt")
    from pypfopt.base_optimizer import BaseOptimizer

    if isinstance(optimizer, str):
        # Normalize once instead of lowering the string in every comparison
        optimizer_name = optimizer.lower()
        if optimizer_name == "efficient_frontier":
            from pypfopt.efficient_frontier import EfficientFrontier

            return resolve_pypfopt_func_call(EfficientFrontier, **kwargs)
        if optimizer_name == "efficient_cdar":
            from pypfopt.efficient_frontier import EfficientCDaR

            return resolve_pypfopt_func_call(EfficientCDaR, **kwargs)
        if optimizer_name == "efficient_cvar":
            from pypfopt.efficient_frontier import EfficientCVaR

            return resolve_pypfopt_func_call(EfficientCVaR, **kwargs)
        if optimizer_name == "efficient_semivariance":
            from pypfopt.efficient_frontier import EfficientSemivariance

            return resolve_pypfopt_func_call(EfficientSemivariance, **kwargs)
        if optimizer_name in ("black_litterman", "bl"):
            from pypfopt.black_litterman import BlackLittermanModel

            return resolve_pypfopt_func_call(
                BlackLittermanModel,
                var_kwarg_names=["market_caps", "risk_free_rate"],
                **kwargs,
            )
        if optimizer_name in ("hierarchical_portfolio", "hrpopt", "hrp"):
            from pypfopt.hierarchical_portfolio import HRPOpt

            return resolve_pypfopt_func_call(HRPOpt, **kwargs)
        if optimizer_name == "cla":
            from pypfopt.cla import CLA

            return resolve_pypfopt_func_call(CLA, **kwargs)
        import pypfopt

        # Fall back to any attribute of the package, using the name exactly as provided
        if hasattr(pypfopt, optimizer):
            return resolve_pypfopt_func_call(getattr(pypfopt, optimizer), **kwargs)
        raise NotImplementedError("Optimizer '{}' is not supported".format(optimizer))
    if isinstance(optimizer, type) and issubclass(optimizer, BaseOptimizer):
        return resolve_pypfopt_func_call(optimizer, **kwargs)
    if isinstance(optimizer, BaseOptimizer):
        return optimizer
    raise NotImplementedError("Optimizer {} is not supported".format(optimizer))
To specify the covariance matrix, use `cov_matrix` (see `resolve_pypfopt_cov_matrix`). All other keyword arguments in `**kwargs` are used by `resolve_pypfopt_func_call`. Each objective can be a function, an attribute of `pypfopt.objective_functions`, or an iterable of such. Each constraint can be a function or an interable of such. The target can be an attribute of the optimizer, or a stand-alone function. If `target_is_convex` is True, the function is added as a convex function. Otherwise, the function is added as a non-convex function. The keyword arguments `weights_sum_to_one` and those starting with `target` are passed `pypfopt.base_optimizer.BaseConvexOptimizer.convex_objective` and `pypfopt.base_optimizer.BaseConvexOptimizer.nonconvex_objective` respectively. Set `ignore_opt_errors` to True to ignore any target optimization errors. Set `ignore_errors` to True to ignore any errors, even those caused by the user. If `discrete_allocation` is True, resolves `pypfopt.discrete_allocation.DiscreteAllocation` and calls `allocation_method` as an attribute of the allocation object. Any function is resolved using `resolve_pypfopt_func_call`. For defaults, see `pypfopt` under `vectorbtpro._settings.pfopt`. Usage: * Using mean historical returns, Ledoit-Wolf covariance matrix with constant variance, and efficient frontier: ```pycon >>> from vectorbtpro import * >>> data = vbt.YFData.pull(["MSFT", "AMZN", "KO", "MA"]) ``` [=100% "100%"]{: .candystripe .candystripe-animate } ```pycon >>> vbt.pypfopt_optimize(prices=data.get("Close")) {'MSFT': 0.13324, 'AMZN': 0.10016, 'KO': 0.03229, 'MA': 0.73431} ``` * EMA historical returns and sample covariance: ```pycon >>> vbt.pypfopt_optimize( ... prices=data.get("Close"), ... expected_returns="ema_historical_return", ... cov_matrix="sample_cov" ... 
) {'MSFT': 0.08984, 'AMZN': 0.0, 'KO': 0.91016, 'MA': 0.0} ``` * EMA historical returns, efficient Conditional Value at Risk, and other parameters automatically passed to their respective functions. Optimized towards lowest CVaR: ```pycon >>> vbt.pypfopt_optimize( ... prices=data.get("Close"), ... expected_returns="ema_historical_return", ... optimizer="efficient_cvar", ... beta=0.9, ... weight_bounds=(-1, 1), ... target="min_cvar" ... ) {'MSFT': 0.14779, 'AMZN': 0.07224, 'KO': 0.77552, 'MA': 0.00445} ``` * Adding custom objectives: ```pycon >>> vbt.pypfopt_optimize( ... prices=data.get("Close"), ... objectives=["L2_reg"], ... gamma=0.1, ... target="min_volatility" ... ) {'MSFT': 0.22228, 'AMZN': 0.15685, 'KO': 0.28712, 'MA': 0.33375} ``` * Adding custom constraints: ```pycon >>> vbt.pypfopt_optimize( ... prices=data.get("Close"), ... constraints=[lambda w: w[data.symbols.index("MSFT")] <= 0.1] ... ) {'MSFT': 0.1, 'AMZN': 0.10676, 'KO': 0.04341, 'MA': 0.74982} ``` * Optimizing towards a custom convex objective (to add a non-convex objective, set `target_is_convex` to False): ```pycon >>> import cvxpy as cp >>> def logarithmic_barrier_objective(w, cov_matrix, k=0.1): ... log_sum = cp.sum(cp.log(w)) ... var = cp.quad_form(w, cov_matrix) ... return var - k * log_sum >>> pypfopt_optimize( ... prices=data.get("Close"), ... target=logarithmic_barrier_objective ... 
) {'MSFT': 0.24595, 'AMZN': 0.23047, 'KO': 0.25862, 'MA': 0.26496} ``` """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("pypfopt") from pypfopt.exceptions import OptimizationError from cvxpy.error import SolverError from vectorbtpro._settings import settings pypfopt_cfg = dict(settings["pfopt"]["pypfopt"]) def _resolve_setting(k, v): setting = pypfopt_cfg.pop(k) if v is None: return setting return v target = _resolve_setting("target", target) target_is_convex = _resolve_setting("target_is_convex", target_is_convex) weights_sum_to_one = _resolve_setting("weights_sum_to_one", weights_sum_to_one) target_constraints = _resolve_setting("target_constraints", target_constraints) target_solver = _resolve_setting("target_solver", target_solver) target_initial_guess = _resolve_setting("target_initial_guess", target_initial_guess) objectives = _resolve_setting("objectives", objectives) constraints = _resolve_setting("constraints", constraints) sector_mapper = _resolve_setting("sector_mapper", sector_mapper) sector_lower = _resolve_setting("sector_lower", sector_lower) sector_upper = _resolve_setting("sector_upper", sector_upper) discrete_allocation = _resolve_setting("discrete_allocation", discrete_allocation) allocation_method = _resolve_setting("allocation_method", allocation_method) silence_warnings = _resolve_setting("silence_warnings", silence_warnings) ignore_opt_errors = _resolve_setting("ignore_opt_errors", ignore_opt_errors) ignore_errors = _resolve_setting("ignore_errors", ignore_errors) kwargs = merge_dicts(pypfopt_cfg, kwargs) if "cache" not in kwargs: kwargs["cache"] = {} if "used_arg_names" not in kwargs: kwargs["used_arg_names"] = set() try: with WarningsFiltered(entries="ignore" if silence_warnings else None): optimizer = kwargs["optimizer"] = resolve_pypfopt_optimizer(**kwargs) if objectives is not None: if not checks.is_iterable(objectives) or isinstance(objectives, str): objectives = [objectives] for objective in objectives: if 
isinstance(objective, str): import pypfopt.objective_functions objective = getattr(pypfopt.objective_functions, objective) objective_kwargs = resolve_pypfopt_func_kwargs(objective, **kwargs) optimizer.add_objective(objective, **objective_kwargs) if constraints is not None: if not checks.is_iterable(constraints): constraints = [constraints] for constraint in constraints: optimizer.add_constraint(constraint) if sector_mapper is not None: if sector_lower is None: sector_lower = {} if sector_upper is None: sector_upper = {} optimizer.add_sector_constraints(sector_mapper, sector_lower, sector_upper) try: if isinstance(target, str): resolve_pypfopt_func_call(getattr(optimizer, target), **kwargs) else: if target_is_convex: optimizer.convex_objective( target, weights_sum_to_one=weights_sum_to_one, **resolve_pypfopt_func_kwargs(target, **kwargs), ) else: optimizer.nonconvex_objective( target, objective_args=tuple(resolve_pypfopt_func_kwargs(target, **kwargs).values()), weights_sum_to_one=weights_sum_to_one, constraints=target_constraints, solver=target_solver, initial_guess=target_initial_guess, ) except (OptimizationError, SolverError, ValueError) as e: if isinstance(e, ValueError) and "expected return exceeding the risk-free rate" not in str(e): raise e if ignore_opt_errors: warn(str(e)) return {} raise e weights = kwargs["weights"] = resolve_pypfopt_func_call(optimizer.clean_weights, **kwargs) if discrete_allocation: from pypfopt.discrete_allocation import DiscreteAllocation allocator = resolve_pypfopt_func_call(DiscreteAllocation, **kwargs) return resolve_pypfopt_func_call(getattr(allocator, allocation_method), **kwargs)[0] passed_arg_names = set(kwargs.keys()) passed_arg_names.remove("cache") passed_arg_names.remove("used_arg_names") passed_arg_names.remove("optimizer") passed_arg_names.remove("weights") unused_arg_names = passed_arg_names.difference(kwargs["used_arg_names"]) if len(unused_arg_names) > 0: warn(f"Some arguments were not used: {unused_arg_names}") if not 
def prepare_returns(
    returns: tp.AnyArray2d,
    nan_to_zero: bool = True,
    dropna_rows: bool = True,
    dropna_cols: bool = True,
    dropna_any: bool = True,
) -> tp.Frame:
    """Clean a two-dimensional array of returns.

    The input is converted with `to_pd_array` and must result in a DataFrame, otherwise
    `ValueError` is raised. An empty DataFrame is returned right away.

    If any of the flags is enabled, infinite values are replaced by NaN. If `nan_to_zero` is True,
    NaN values are then replaced by zero. If `dropna_rows` is True (and `nan_to_zero` is True or
    `dropna_any` is False), rows that contain only zero/NaN values are dropped. If `dropna_cols`
    is True, columns that contain only zero/NaN values are dropped. Finally, if `nan_to_zero` is
    False and `dropna_any` is True, every row containing at least one NaN is dropped."""
    frame = to_pd_array(returns)
    if not isinstance(frame, pd.DataFrame):
        raise ValueError("Returns must be a two-dimensional array")
    if frame.size == 0:
        return frame

    if any((nan_to_zero, dropna_rows, dropna_cols, dropna_any)):
        frame = frame.replace([np.inf, -np.inf], np.nan)
    if nan_to_zero:
        frame = frame.fillna(0.0)

    if dropna_rows or dropna_cols:
        # The mask is built once, before any row or column gets removed
        valid_mask = (frame != 0) if nan_to_zero else ~frame.isnull()
        if dropna_rows and (nan_to_zero or not dropna_any):
            frame = frame.loc[valid_mask.any(axis=1)]
            if frame.size == 0:
                return frame
        if dropna_cols:
            frame = frame.loc[:, valid_mask.any(axis=0)]
            if frame.size == 0:
                return frame

    if dropna_any and not nan_to_zero:
        frame = frame.dropna()
    return frame


def resolve_riskfolio_func_kwargs(
    riskfolio_func: tp.Callable,
    unused_arg_names: tp.Optional[tp.Set[str]] = None,
    func_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.Kwargs:
    """Select keyword arguments belonging to `riskfolio_func`.

    Keeps only those entries of `kwargs` whose names appear among the argument names of
    `riskfolio_func` (as returned by `get_func_arg_names`). Each kept name is also removed from
    `unused_arg_names` (if provided), which is modified in place.

    The kept arguments are passed through `select_pfopt_func_kwargs`. If `func_kwargs` is provided,
    it is wrapped with `pfopt_func_dict`, passed through `select_pfopt_func_kwargs` as well, and
    merged over the kept arguments using `merge_dicts`."""
    accepted_names = get_func_arg_names(riskfolio_func)
    matched = {name: value for name, value in kwargs.items() if name in accepted_names}
    if unused_arg_names is not None:
        for name in matched:
            # Anything handed over to the function counts as used
            if name in unused_arg_names:
                unused_arg_names.remove(name)

    selected = select_pfopt_func_kwargs(riskfolio_func, matched)
    if func_kwargs is None:
        return selected
    return merge_dicts(
        selected,
        select_pfopt_func_kwargs(riskfolio_func, pfopt_func_dict(func_kwargs)),
    )
tp.Optional[tp.Sequence[int]] = None, ) -> tp.Frame: """Resolve asset classes for Riskfolio-Lib. Supports the following formats: * None: Takes columns where the bottom-most level is assumed to be assets * Index: Each level in the index must be a different asset class set * Nested dict: Each sub-dict must be a different asset class set * Sequence of strings or ints: Matches them against level names in the columns. If the columns have a single level, or some level names were not found, uses the sequence directly as one class asset set named 'Class'. * Sequence of dicts: Each dict becomes a row in the new DataFrame * DataFrame where the first column is the asset list and the next columns are the different asset’s classes sets (this is the target format accepted by Riskfolio-Lib). See an example [here](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints). !!! note If `asset_classes` is neither None nor a DataFrame, the bottom-most level in `columns` gets renamed to 'Assets' and becomes the first column of the new DataFrame.""" if asset_classes is None: asset_classes = columns.to_frame().reset_index(drop=True).iloc[:, ::-1] asset_classes = asset_classes.rename(columns={asset_classes.columns[0]: "Assets"}) if not isinstance(asset_classes, pd.DataFrame): if isinstance(asset_classes, dict): asset_classes = pd.DataFrame(asset_classes) elif isinstance(asset_classes, pd.Index): asset_classes = asset_classes.to_frame().reset_index(drop=True) elif checks.is_sequence(asset_classes) and isinstance(asset_classes[0], int): asset_classes = select_levels(columns, asset_classes).to_frame().reset_index(drop=True) elif checks.is_sequence(asset_classes) and isinstance(asset_classes[0], str): if isinstance(columns, pd.MultiIndex) and set(asset_classes) <= set(columns.names): asset_classes = select_levels(columns, asset_classes).to_frame().reset_index(drop=True) else: asset_classes = pd.Index(asset_classes, 
def _resolve_riskfolio_table(table: tp.Union[tp.Frame, tp.Sequence], columns: tp.Sequence[str]) -> tp.Frame:
    """Bring a constraints/views table into the column layout given by `columns`.

    Anything that is not a DataFrame is converted into one: a dict via `pd.DataFrame`,
    everything else via `pd.DataFrame.from_records`. Column names are title-cased before
    being matched against `columns`. Columns missing in the input are created, missing
    values become empty strings, and empty values under 'Disabled' become False.

    The passed object is never modified; a new DataFrame is returned."""
    if not isinstance(table, pd.DataFrame):
        if isinstance(table, dict):
            table = pd.DataFrame(table)
        else:
            table = pd.DataFrame.from_records(table)
    # Relabel a shallow copy so that the caller's DataFrame keeps its original column names
    table = table.copy(deep=False)
    table.columns = table.columns.str.title()

    new_table = pd.DataFrame(columns=list(columns), dtype=object)
    for c in new_table.columns:
        if c in table.columns:
            new_table[c] = table[c]
    new_table = new_table.fillna("")
    # Assign the result back instead of a chained in-place replace, which has no effect
    # on the DataFrame when pandas operates in Copy-on-Write mode
    new_table["Disabled"] = new_table["Disabled"].replace("", False)
    return new_table


def resolve_assets_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
    """Resolve asset constraints for Riskfolio-Lib.

    Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints),
    also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
    Dicts don't have to specify all column names, the function will autofill any missing elements/columns.

    The passed object is not modified."""
    return _resolve_riskfolio_table(
        constraints,
        [
            "Disabled",
            "Type",
            "Set",
            "Position",
            "Sign",
            "Weight",
            "Type Relative",
            "Relative Set",
            "Relative",
            "Factor",
        ],
    )


def resolve_factors_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
    """Resolve factors constraints for Riskfolio-Lib.

    Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_constraints),
    also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
    Dicts don't have to specify all column names, the function will autofill any missing elements/columns.

    The passed object is not modified."""
    return _resolve_riskfolio_table(
        constraints,
        [
            "Disabled",
            "Factor",
            "Sign",
            "Value",
            "Relative Factor",
        ],
    )


def resolve_assets_views(views: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
    """Resolve asset views for Riskfolio-Lib.

    Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_views),
    also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
    Dicts don't have to specify all column names, the function will autofill any missing elements/columns.

    The passed object is not modified."""
    return _resolve_riskfolio_table(
        views,
        [
            "Disabled",
            "Type",
            "Set",
            "Position",
            "Sign",
            "Return",
            "Type Relative",
            "Relative Set",
            "Relative",
        ],
    )


def resolve_factors_views(views: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
    """Resolve factors views for Riskfolio-Lib.

    Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_views),
    also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
    Dicts don't have to specify all column names, the function will autofill any missing elements/columns.

    The passed object is not modified."""
    return _resolve_riskfolio_table(
        views,
        [
            "Disabled",
            "Factor",
            "Sign",
            "Value",
            "Relative Factor",
        ],
    )


def resolve_hrp_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
    """Resolve HRP constraints for Riskfolio-Lib.

    Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.hrp_constraints),
    also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
    Dicts don't have to specify all column names, the function will autofill any missing elements/columns.

    The passed object is not modified."""
    return _resolve_riskfolio_table(
        constraints,
        [
            "Disabled",
            "Type",
            "Set",
            "Position",
            "Sign",
            "Weight",
        ],
    )
Riskfolio-Lib. Args: returns (array_like): A dataframe that contains the returns of the assets. nan_to_zero (bool): Whether to convert NaN values to zero. dropna_rows (bool): Whether to drop rows with all NaN/zero values. Gets applied only if `nan_to_zero` is True or `dropna_any` is False. dropna_cols (bool): Whether to drop columns with all NaN/zero values. dropna_any (bool): Whether to drop any NaN values. Gets applied only if `nan_to_zero` is False. factors (array_like): A dataframe that contains the factors. port (Portfolio or HCPortfolio): Already initialized portfolio. port_cls (str or type): Portfolio class. Supports the following values: * None: Uses `Portfolio` * 'hc' or 'hcportfolio' (case-insensitive): Uses `HCPortfolio` * Other string: Uses attribute of `riskfolio` * Class: Uses a custom class opt_method (str or callable): Optimization method. Supports the following values: * None or 'optimization': Uses `port.optimization` (where `port` is a portfolio instance) * 'wc' or 'wc_optimization': Uses `port.wc_optimization` * 'rp' or 'rp_optimization': Uses `port.rp_optimization` * 'rrp' or 'rrp_optimization': Uses `port.rrp_optimization` * 'owa' or 'owa_optimization': Uses `port.owa_optimization` * String: Uses attribute of `port` * Callable: Uses a custom optimization function stats_methods (str or sequence of str): Sequence of stats methods to call before optimization. If None, tries to automatically populate the sequence using `opt_method` and `model`. For example, calls `port.assets_stats` if `model="Classic"` is used. Also, if `func_kwargs` is not empty, adds all functions whose name ends with '_stats'. model (str): The model used to optimize the portfolio. asset_classes (any): Asset classes matrix. See `resolve_asset_classes` for possible formats. constraints_method (str): Constraints method. 
Supports the following values: * 'assets' or 'assets_constraints': [assets constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints) * 'factors' or 'factors_constraints': [factors constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_constraints) * 'hrp' or 'hrp_constraints': [HRP constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.hrp_constraints) If None and the class `Portfolio` is used, will use factors constraints if `factors_stats` is used, otherwise assets constraints. If the class `HCPortfolio` is used, will use HRP constraints. constraints (any): Constraints matrix. See `resolve_assets_constraints` for possible formats of assets constraints, `resolve_factors_constraints` for possible formats of factors constraints, and `resolve_hrp_constraints` for possible formats of HRP constraints. views_method (str): Views method. Supports the following values: * 'assets' or 'assets_views': [assets views](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_views) * 'factors' or 'factors_views': [factors views](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_views) If None, will use factors views if `blfactors_stats` is used, otherwise assets views. views (any): Views matrix. See `resolve_assets_views` for possible formats of assets views and `resolve_factors_views` for possible formats of factors views. solvers (list of str): Solvers. sol_params (dict): Solver parameters. freq (frequency_like): Frequency to be used to compute the annualization factor. Make sure to provide it when using views. year_freq (frequency_like): Year frequency to be used to compute the annualization factor. Make sure to provide it when using views. pre_opt (bool): Whether to pre-optimize the portfolio with `pre_opt_kwargs`. 
pre_opt_kwargs (dict): Call `riskfolio_optimize` with these keyword arguments and use the returned portfolio for further optimization. pre_opt_as_w (bool): Whether to use the weights as `w` from the pre-optimization step. func_kwargs (dict): Further keyword arguments by function. Can be used to override any arguments from `kwargs` matched with the function, or to add more arguments. Will be wrapped with `pfopt_func_dict` and passed to `select_pfopt_func_kwargs` when calling each Riskfolio-Lib's function. silence_warnings (bool): Whether to silence all warnings. return_port (bool): Whether to also return the portfolio. ignore_errors (bool): Whether to ignore any errors, even those caused by the user. **kwargs: Keyword arguments that will be passed to any Riskfolio-Lib's function that needs them (i.e., lists any of them in its signature). For defaults, see `riskfolio` under `vectorbtpro._settings.pfopt`. Usage: * Classic Mean Risk Optimization: ```pycon >>> from vectorbtpro import * >>> data = vbt.YFData.pull(["MSFT", "AMZN", "KO", "MA"]) >>> returns = data.close.vbt.to_returns() ``` [=100% "100%"]{: .candystripe .candystripe-animate } ```pycon >>> vbt.riskfolio_optimize( ... returns, ... method_mu='hist', method_cov='hist', d=0.94, # assets_stats ... model='Classic', rm='MV', obj='Sharpe', hist=True, rf=0, l=0 # optimization ... ) {'MSFT': 0.26297126323056036, 'AMZN': 0.13984467450137006, 'KO': 0.35870315943426767, 'MA': 0.238480902833802} ``` * The same by splitting arguments: ```pycon >>> vbt.riskfolio_optimize( ... returns, ... func_kwargs=dict( ... assets_stats=dict(method_mu='hist', method_cov='hist', d=0.94), ... optimization=dict(model='Classic', rm='MV', obj='Sharpe', hist=True, rf=0, l=0) ... ) ... ) {'MSFT': 0.26297126323056036, 'AMZN': 0.13984467450137006, 'KO': 0.35870315943426767, 'MA': 0.238480902833802} ``` * Asset constraints: ```pycon >>> vbt.riskfolio_optimize( ... returns, ... constraints=[ ... { ... "Type": "Assets", ... "Position": "MSFT", ... 
"Sign": "<=", ... "Weight": 0.01 ... } ... ] ... ) {'MSFT': 0.009999990814976588, 'AMZN': 0.19788481506569947, 'KO': 0.4553600308839969, 'MA': 0.336755163235327} ``` * Asset class constraints: ```pycon >>> vbt.riskfolio_optimize( ... returns, ... asset_classes=["C1", "C1", "C2", "C2"], ... constraints=[ ... { ... "Type": "Classes", ... "Set": "Class", ... "Position": "C1", ... "Sign": "<=", ... "Weight": 0.1 ... } ... ] ... ) {'MSFT': 0.03501297245802569, 'AMZN': 0.06498702655063979, 'KO': 0.4756624658301967, 'MA': 0.4243375351611379} ``` * Hierarchical Risk Parity (HRP) Portfolio Optimization: ```pycon >>> vbt.riskfolio_optimize( ... returns, ... port_cls="HCPortfolio", ... model='HRP', ... codependence='pearson', ... rm='MV', ... rf=0, ... linkage='single', ... max_k=10, ... leaf_order=True ... ) {'MSFT': 0.19091632057853536, 'AMZN': 0.11069893826556164, 'KO': 0.28589872132122485, 'MA': 0.41248601983467814} ``` """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("riskfolio") import riskfolio as rp from vectorbtpro._settings import settings riskfolio_cfg = dict(settings["pfopt"]["riskfolio"]) def _resolve_setting(k, v): setting = riskfolio_cfg.pop(k) if v is None: return setting return v nan_to_zero = _resolve_setting("nan_to_zero", nan_to_zero) dropna_rows = _resolve_setting("dropna_rows", dropna_rows) dropna_cols = _resolve_setting("dropna_cols", dropna_cols) dropna_any = _resolve_setting("dropna_any", dropna_any) factors = _resolve_setting("factors", factors) port = _resolve_setting("port", port) port_cls = _resolve_setting("port_cls", port_cls) opt_method = _resolve_setting("opt_method", opt_method) stats_methods = _resolve_setting("stats_methods", stats_methods) model = _resolve_setting("model", model) asset_classes = _resolve_setting("asset_classes", asset_classes) constraints_method = _resolve_setting("constraints_method", constraints_method) constraints = _resolve_setting("constraints", constraints) views_method = 
_resolve_setting("views_method", views_method) views = _resolve_setting("views", views) solvers = _resolve_setting("solvers", solvers) sol_params = _resolve_setting("sol_params", sol_params) freq = _resolve_setting("freq", freq) year_freq = _resolve_setting("year_freq", year_freq) pre_opt = _resolve_setting("pre_opt", pre_opt) pre_opt_kwargs = merge_dicts(riskfolio_cfg.pop("pre_opt_kwargs"), pre_opt_kwargs) pre_opt_as_w = _resolve_setting("pre_opt_as_w", pre_opt_as_w) func_kwargs = merge_dicts(riskfolio_cfg.pop("func_kwargs"), func_kwargs) silence_warnings = _resolve_setting("silence_warnings", silence_warnings) return_port = _resolve_setting("return_port", return_port) ignore_errors = _resolve_setting("ignore_errors", ignore_errors) kwargs = merge_dicts(riskfolio_cfg, kwargs) if pre_opt_kwargs is None: pre_opt_kwargs = {} if func_kwargs is None: func_kwargs = {} func_kwargs = pfopt_func_dict(func_kwargs) unused_arg_names = set(kwargs.keys()) try: with WarningsFiltered(entries="ignore" if silence_warnings else None): # Prepare returns new_returns = prepare_returns( returns, nan_to_zero=nan_to_zero, dropna_rows=dropna_rows, dropna_cols=dropna_cols, dropna_any=dropna_any, ) col_indices = [i for i, c in enumerate(returns.columns) if c in new_returns.columns] returns = new_returns if returns.size == 0: return {} # Pre-optimize if pre_opt: w, port = riskfolio_optimize(returns, port=port, return_port=True, **pre_opt_kwargs) if pre_opt_as_w: w = pd.DataFrame.from_records([w]).T.rename(columns={0: "weights"}) kwargs["w"] = w unused_arg_names.add("w") # Build portfolio if port_cls is None: port_cls = rp.Portfolio elif isinstance(port_cls, str) and port_cls.lower() in ("hc", "hcportfolio"): port_cls = rp.HCPortfolio elif isinstance(port_cls, str): port_cls = getattr(rp, port_cls) else: port_cls = port_cls matched_kwargs = resolve_riskfolio_func_kwargs( port_cls, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) if port is None: port = port_cls(returns, 
**matched_kwargs) else: for k, v in matched_kwargs.items(): setattr(port, k, v) if solvers is not None: port.solvers = list(solvers) if sol_params is not None: port.sol_params = dict(sol_params) if factors is not None: factors = to_pd_array(factors).dropna() port.factors = factors # Resolve optimization and stats methods if opt_method is None: if len(func_kwargs) > 0: for name_or_func in func_kwargs: if isinstance(name_or_func, str): if name_or_func.endswith("optimization"): if opt_method is not None: raise ValueError("Function keyword arguments list multiple optimization methods") opt_method = name_or_func if opt_method is None: opt_method = "optimization" if stats_methods is None: if len(func_kwargs) > 0: for name_or_func in func_kwargs: if isinstance(name_or_func, str): if name_or_func.endswith("_stats"): if stats_methods is None: stats_methods = [] stats_methods.append(name_or_func) if isinstance(port, rp.Portfolio): if isinstance(opt_method, str) and opt_method.lower() == "optimization": opt_func = port.optimization if model is None: opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs) model = opt_func_kwargs.get("model", "Classic") if model.lower() == "classic": model = "Classic" if stats_methods is None: stats_methods = ["assets_stats"] elif model.lower() == "fm": model = "FM" if stats_methods is None: stats_methods = ["assets_stats", "factors_stats"] elif model.lower() == "bl": model = "BL" if stats_methods is None: stats_methods = ["assets_stats", "blacklitterman_stats"] elif model.lower() in ("bl_fm", "blfm"): model = "BL_FM" if stats_methods is None: stats_methods = ["assets_stats", "factors_stats", "blfactors_stats"] elif isinstance(opt_method, str) and opt_method.lower() in ("wc", "wc_optimization"): opt_func = port.wc_optimization if stats_methods is None: stats_methods = ["assets_stats", "wc_stats"] elif isinstance(opt_method, str) and opt_method.lower() in ("rp", "rp_optimization"): opt_func = port.rp_optimization if model is None: 
opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs) model = opt_func_kwargs.get("model", "Classic") if model.lower() == "classic": model = "Classic" if stats_methods is None: stats_methods = ["assets_stats"] elif model.lower() == "fm": model = "FM" if stats_methods is None: stats_methods = ["assets_stats", "factors_stats"] elif isinstance(opt_method, str) and opt_method.lower() in ("rrp", "rrp_optimization"): opt_func = port.rrp_optimization if model is None: opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs) model = opt_func_kwargs.get("model", "Classic") if model.lower() == "classic": model = "Classic" if stats_methods is None: stats_methods = ["assets_stats"] elif model.lower() == "fm": model = "FM" if stats_methods is None: stats_methods = ["assets_stats", "factors_stats"] elif isinstance(opt_method, str) and opt_method.lower() in ("owa", "owa_optimization"): opt_func = port.owa_optimization if stats_methods is None: stats_methods = ["assets_stats"] elif isinstance(opt_method, str): opt_func = getattr(port, opt_method) else: opt_func = opt_method else: if isinstance(opt_method, str): opt_func = getattr(port, opt_method) else: opt_func = opt_method if model is not None: kwargs["model"] = model unused_arg_names.add("model") if stats_methods is None: stats_methods = [] # Apply constraints if constraints is not None: if constraints_method is None: if isinstance(port, rp.Portfolio): if "factors_stats" in stats_methods: constraints_method = "factors" else: constraints_method = "assets" elif isinstance(port, rp.HCPortfolio): constraints_method = "hrp" else: raise ValueError("Constraints method is required") if constraints_method.lower() in ("assets", "assets_constraints"): asset_classes = resolve_asset_classes(asset_classes, returns.columns, col_indices) kwargs["asset_classes"] = asset_classes unused_arg_names.add("asset_classes") constraints = resolve_assets_constraints(constraints) kwargs["constraints"] = constraints 
unused_arg_names.add("constraints") matched_kwargs = resolve_riskfolio_func_kwargs( rp.assets_constraints, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) port.ainequality, port.binequality = warn_stdout(rp.assets_constraints)(**matched_kwargs) elif constraints_method.lower() in ("factors", "factors_constraints"): if "loadings" not in kwargs: matched_kwargs = resolve_riskfolio_func_kwargs( rp.loadings_matrix, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) if "X" not in matched_kwargs: matched_kwargs["X"] = port.factors if "Y" not in matched_kwargs: matched_kwargs["Y"] = port.returns loadings = warn_stdout(rp.loadings_matrix)(**matched_kwargs) kwargs["loadings"] = loadings unused_arg_names.add("loadings") constraints = resolve_factors_constraints(constraints) kwargs["constraints"] = constraints unused_arg_names.add("constraints") matched_kwargs = resolve_riskfolio_func_kwargs( rp.factors_constraints, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) port.ainequality, port.binequality = warn_stdout(rp.factors_constraints)(**matched_kwargs) elif constraints_method.lower() in ("hrp", "hrp_constraints"): asset_classes = resolve_asset_classes(asset_classes, returns.columns, col_indices) kwargs["asset_classes"] = asset_classes unused_arg_names.add("asset_classes") constraints = resolve_hrp_constraints(constraints) kwargs["constraints"] = constraints unused_arg_names.add("constraints") matched_kwargs = resolve_riskfolio_func_kwargs( rp.hrp_constraints, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) port.w_max, port.w_min = warn_stdout(rp.hrp_constraints)(**matched_kwargs) else: raise ValueError(f"Constraints method '{constraints_method}' is not supported") # Resolve views if views is not None: if views_method is None: if "blfactors_stats" in stats_methods: views_method = "factors" else: views_method = "assets" if views_method.lower() in ("assets", "assets_views"): asset_classes 
= resolve_asset_classes(asset_classes, returns.columns, col_indices) kwargs["asset_classes"] = asset_classes unused_arg_names.add("asset_classes") views = resolve_assets_views(views) kwargs["views"] = views unused_arg_names.add("views") matched_kwargs = resolve_riskfolio_func_kwargs( rp.assets_views, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) P, Q = warn_stdout(rp.assets_views)(**matched_kwargs) ann_factor = ReturnsAccessor.get_ann_factor(year_freq=year_freq, freq=freq) if ann_factor is not None: Q /= ann_factor else: warn(f"Set frequency and year frequency to adjust expected returns") kwargs["P"] = P unused_arg_names.add("P") kwargs["Q"] = Q unused_arg_names.add("Q") elif views_method.lower() in ("factors", "factors_views"): if "loadings" not in kwargs: matched_kwargs = resolve_riskfolio_func_kwargs( rp.loadings_matrix, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) if "X" not in matched_kwargs: matched_kwargs["X"] = port.factors if "Y" not in matched_kwargs: matched_kwargs["Y"] = port.returns loadings = warn_stdout(rp.loadings_matrix)(**matched_kwargs) kwargs["loadings"] = loadings unused_arg_names.add("loadings") if "B" not in kwargs: kwargs["B"] = kwargs["loadings"] unused_arg_names.add("B") views = resolve_factors_views(views) kwargs["views"] = views unused_arg_names.add("views") matched_kwargs = resolve_riskfolio_func_kwargs( rp.factors_views, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) P_f, Q_f = warn_stdout(rp.factors_views)(**matched_kwargs) ann_factor = ReturnsAccessor.get_ann_factor(year_freq=year_freq, freq=freq) if ann_factor is not None: Q_f /= ann_factor else: warn(f"Set frequency and year frequency to adjust expected returns") kwargs["P_f"] = P_f unused_arg_names.add("P_f") kwargs["Q_f"] = Q_f unused_arg_names.add("Q_f") else: raise ValueError(f"Views method '{constraints_method}' is not supported") # Run stats for stats_method in stats_methods: stats_func = 
getattr(port, stats_method) matched_kwargs = resolve_riskfolio_func_kwargs( stats_func, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) warn_stdout(stats_func)(**matched_kwargs) # Run optimization matched_kwargs = resolve_riskfolio_func_kwargs( opt_func, unused_arg_names=unused_arg_names, func_kwargs=func_kwargs, **kwargs, ) weights = warn_stdout(opt_func)(**matched_kwargs) # Post-process weights if len(unused_arg_names) > 0: warn(f"Some arguments were not used: {unused_arg_names}") if weights is None: weights = {} if isinstance(weights, pd.DataFrame): if "weights" not in weights.columns: raise ValueError("Weights column wasn't returned") weights = weights["weights"] if return_port: return dict(weights), port return dict(weights) except Exception as e: if ignore_errors: return {} raise e # ############# PortfolioOptimizer ############# # PortfolioOptimizerT = tp.TypeVar("PortfolioOptimizerT", bound="PortfolioOptimizer") class PortfolioOptimizer(Analyzable): """Class that exposes methods for generating allocations.""" @hybrid_method def row_stack( cls_or_self: tp.MaybeType[PortfolioOptimizerT], *objs: tp.MaybeTuple[PortfolioOptimizerT], wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> PortfolioOptimizerT: """Stack multiple `PortfolioOptimizer` instances along rows. 
@hybrid_method
def column_stack(
    cls_or_self: tp.MaybeType[PortfolioOptimizerT],
    *objs: tp.MaybeTuple[PortfolioOptimizerT],
    wrapper_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """Stack multiple `PortfolioOptimizer` instances along columns.

    Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.

    Any of the keys `wrapper`, `alloc_records`, and `allocations` that is already present
    in `kwargs` is taken as-is instead of being built from the stacked objects."""
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        # Called on an instance: the instance becomes the first object to stack
        cls = type(cls_or_self)
        objs = (cls_or_self, *objs)
    if len(objs) == 1:
        # A single positional argument is treated as the collection of objects
        objs = objs[0]
    objs = list(objs)
    for candidate in objs:
        if not checks.is_instance_of(candidate, PortfolioOptimizer):
            raise TypeError("Each object to be merged must be an instance of PortfolioOptimizer")

    if "wrapper" not in kwargs:
        if wrapper_kwargs is None:
            wrapper_kwargs = {}
        wrappers = [obj.wrapper for obj in objs]
        kwargs["wrapper"] = ArrayWrapper.column_stack(
            *wrappers,
            union_index=False,
            **wrapper_kwargs,
        )

    if "alloc_records" not in kwargs:
        # The type of the first object's records is the reference for all others
        records_cls = None
        for obj in objs:
            if records_cls is None:
                records_cls = type(obj.alloc_records)
                continue
            if not isinstance(obj.alloc_records, records_cls):
                raise TypeError("Objects to be merged must have the same type for alloc_records")
        records_to_stack = [obj.alloc_records for obj in objs]
        kwargs["alloc_records"] = records_cls.column_stack(
            *records_to_stack,
            wrapper_kwargs=wrapper_kwargs,
        )

    if "allocations" not in kwargs:
        stacked_records = kwargs["alloc_records"]
        record_indices = type(stacked_records).get_column_stack_record_indices(
            *[obj.alloc_records for obj in objs],
            wrapper=stacked_records.wrapper,
        )
        all_allocations = row_stack_arrays([obj._allocations for obj in objs])
        kwargs["allocations"] = all_allocations[record_indices]

    kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    return cls(**kwargs)
@classmethod
def run_allocation_group(
    cls,
    wrapper: ArrayWrapper,
    group_configs: tp.List[dict],
    group_index: tp.Index,
    group_idx: int,
    pre_group_func: tp.Optional[tp.Callable] = None,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
    """Run an allocation group.

    Takes the config at position `group_idx` in `group_configs`, resolves the index points,
    calls the allocation function at each of them, and returns the result of
    `nb.prepare_alloc_points_nb` called on the index points, allocations, and `group_idx`.

    Args:
        wrapper (ArrayWrapper): Wrapper used to generate index points, and whose columns and
            index are used to shape the allocations and to build the task keys.
        group_configs (list of dict): Group configs. Only the one under `group_idx` is run.

            The config must contain all keys popped below; the remaining keys are
            passed as keyword arguments to the allocation function.
        group_index (Index): Group index. Only made available in the template context.
        group_idx (int): Position of the group config to run.
        pre_group_func (callable): Function called with a copy of the group config
            before the config gets unpacked. Can modify the config in place.
    """
    # Work on a shallow copy such that popping doesn't affect the passed config
    group_config = dict(group_configs[group_idx])
    if pre_group_func is not None:
        pre_group_func(group_config)

    allocate_func = group_config.pop("allocate_func")
    every = group_config.pop("every")
    normalize_every = group_config.pop("normalize_every")
    at_time = group_config.pop("at_time")
    start = group_config.pop("start")
    end = group_config.pop("end")
    exact_start = group_config.pop("exact_start")
    on = group_config.pop("on")
    add_delta = group_config.pop("add_delta")
    kind = group_config.pop("kind")
    indexer_method = group_config.pop("indexer_method")
    indexer_tolerance = group_config.pop("indexer_tolerance")
    skip_not_found = group_config.pop("skip_not_found")
    index_points = group_config.pop("index_points")
    rescale_to = group_config.pop("rescale_to")
    jitted_loop = group_config.pop("jitted_loop")
    jitted = group_config.pop("jitted")
    chunked = group_config.pop("chunked")
    template_context = group_config.pop("template_context")
    execute_kwargs = group_config.pop("execute_kwargs")
    args = group_config.pop("args")
    # Whatever is left in the config are keyword arguments of the allocation function
    kwargs = group_config

    # The user-provided context is merged last
    # NOTE(review): presumably later dicts take precedence in merge_dicts - confirm
    template_context = merge_dicts(
        dict(
            group_configs=group_configs,
            group_index=group_index,
            group_idx=group_idx,
            wrapper=wrapper,
            allocate_func=allocate_func,
            every=every,
            normalize_every=normalize_every,
            at_time=at_time,
            start=start,
            end=end,
            exact_start=exact_start,
            on=on,
            add_delta=add_delta,
            kind=kind,
            indexer_method=indexer_method,
            indexer_tolerance=indexer_tolerance,
            skip_not_found=skip_not_found,
            index_points=index_points,
            rescale_to=rescale_to,
            jitted_loop=jitted_loop,
            jitted=jitted,
            chunked=chunked,
            execute_kwargs=execute_kwargs,
            args=args,
            kwargs=kwargs,
        ),
        template_context,
    )

    if index_points is None:
        # Generate index points from the (template-substituted) indexer arguments
        get_index_points_kwargs = substitute_templates(
            dict(
                every=every,
                normalize_every=normalize_every,
                at_time=at_time,
                start=start,
                end=end,
                exact_start=exact_start,
                on=on,
                add_delta=add_delta,
                kind=kind,
                indexer_method=indexer_method,
                indexer_tolerance=indexer_tolerance,
                skip_not_found=skip_not_found,
            ),
            template_context,
            eval_id="get_index_points_defaults",
            strict=True,
        )
        index_points = wrapper.get_index_points(**get_index_points_kwargs)
        # Expose the substituted arguments and the generated points to later templates
        template_context = merge_dicts(
            template_context,
            get_index_points_kwargs,
            dict(index_points=index_points),
        )
    else:
        # Index points were provided (possibly as a template)
        index_points = substitute_templates(
            index_points,
            template_context,
            eval_id="index_points",
            strict=True,
        )
        index_points = to_1d_array(index_points)
        template_context = merge_dicts(template_context, dict(index_points=index_points))

    if jitted_loop:
        # Templates are substituted once, then the entire loop is delegated
        # to the (optionally jitted and chunked) nb.allocate_meta_nb
        allocate_func = substitute_templates(
            allocate_func,
            template_context,
            eval_id="allocate_func",
            strict=True,
        )
        args = substitute_templates(args, template_context, eval_id="args")
        kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
        func = jit_reg.resolve_option(nb.allocate_meta_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        _allocations = func(len(wrapper.columns), index_points, allocate_func, *args, **kwargs)
    else:
        # Build one task per index point; templates are substituted per point
        # with `i` and `index_point` available in the context
        tasks = []
        keys = []
        for i in range(len(index_points)):
            _template_context = merge_dicts(
                dict(
                    i=i,
                    index_point=index_points[i],
                ),
                template_context,
            )
            _allocate_func = substitute_templates(
                allocate_func,
                _template_context,
                eval_id="allocate_func",
                strict=True,
            )
            _args = substitute_templates(args, _template_context, eval_id="args")
            _kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs")
            tasks.append(Task(_allocate_func, *_args, **_kwargs))
            # Label each task by the index value at its index point
            if isinstance(wrapper.index, pd.DatetimeIndex):
                keys.append(dt.readable_datetime(wrapper.index[index_points[i]], freq=wrapper.freq))
            else:
                keys.append(str(wrapper.index[index_points[i]]))
        results = execute(tasks, keys=keys, **execute_kwargs)
        # Build a frame out of the results and extract a 2-dim array
        # with columns ordered as in the wrapper
        _allocations = pd.DataFrame(results, columns=wrapper.columns)
        if isinstance(_allocations.columns, pd.RangeIndex):
            _allocations = _allocations.values
        else:
            _allocations = _allocations[list(wrapper.columns)].values
    if rescale_to is not None:
        _allocations = nb.rescale_allocations_nb(_allocations, rescale_to)
    return nb.prepare_alloc_points_nb(index_points, _allocations, group_idx)
tp.FrequencyLike, Param] = point_idxr_defaults["add_delta"], kind: tp.Union[None, str, Param] = point_idxr_defaults["kind"], indexer_method: tp.Union[None, str, Param] = point_idxr_defaults["indexer_method"], indexer_tolerance: tp.Union[None, str, Param] = point_idxr_defaults["indexer_tolerance"], skip_not_found: tp.Union[bool, Param] = point_idxr_defaults["skip_not_found"], index_points: tp.Union[None, tp.MaybeSequence[int], Param] = None, rescale_to: tp.Union[None, tp.Tuple[float, float], Param] = None, parameterizer: tp.Optional[tp.MaybeType[Parameterizer]] = None, param_search_kwargs: tp.KwargsLike = None, name_tuple_to_str: tp.Union[None, bool, tp.Callable] = None, group_configs: tp.Union[None, tp.Dict[tp.Hashable, tp.Kwargs], tp.Sequence[tp.Kwargs]] = None, pre_group_func: tp.Optional[tp.Callable] = None, jitted_loop: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, template_context: tp.KwargsLike = None, group_execute_kwargs: tp.KwargsLike = None, execute_kwargs: tp.KwargsLike = None, random_subset: tp.Optional[int] = None, clean_index_kwargs: tp.KwargsLike = None, wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> PortfolioOptimizerT: """Generate allocations from an allocation function. Generates date points and allocates at those points. Similar to `PortfolioOptimizer.from_optimize_func`, but generates points using `vectorbtpro.base.wrapping.ArrayWrapper.get_index_points` and makes each point available as `index_point` in the context. Templates can use the following variables: * `i`: Allocation step * `index_point`: Allocation index If `jitted_loop` is True, see `vectorbtpro.portfolio.pfopt.nb.allocate_meta_nb`. Also, in contrast to `PortfolioOptimizer.from_optimize_func`, creates records of type `vectorbtpro.portfolio.pfopt.records.AllocPoints`. Usage: * Allocate uniformly every day: ```pycon >>> from vectorbtpro import * >>> data = vbt.YFData.pull( ... ["MSFT", "AMZN", "AAPL"], ... start="2010-01-01", ... 
end="2020-01-01" ... ) >>> close = data.get("Close") >>> def uniform_allocate_func(n_cols): ... return np.full(n_cols, 1 / n_cols) >>> pfo = vbt.PortfolioOptimizer.from_allocate_func( ... close.vbt.wrapper, ... uniform_allocate_func, ... close.shape[1] ... ) >>> pfo.allocations symbol MSFT AMZN AAPL Date 2010-01-04 00:00:00-05:00 0.333333 0.333333 0.333333 2010-01-05 00:00:00-05:00 0.333333 0.333333 0.333333 2010-01-06 00:00:00-05:00 0.333333 0.333333 0.333333 2010-01-07 00:00:00-05:00 0.333333 0.333333 0.333333 2010-01-08 00:00:00-05:00 0.333333 0.333333 0.333333 ... ... ... ... 2019-12-24 00:00:00-05:00 0.333333 0.333333 0.333333 2019-12-26 00:00:00-05:00 0.333333 0.333333 0.333333 2019-12-27 00:00:00-05:00 0.333333 0.333333 0.333333 2019-12-30 00:00:00-05:00 0.333333 0.333333 0.333333 2019-12-31 00:00:00-05:00 0.333333 0.333333 0.333333 [2516 rows x 3 columns] ``` * Allocate randomly every first date of the year: ```pycon >>> def random_allocate_func(n_cols): ... weights = np.random.uniform(size=n_cols) ... return weights / weights.sum() >>> pfo = vbt.PortfolioOptimizer.from_allocate_func( ... close.vbt.wrapper, ... random_allocate_func, ... close.shape[1], ... every="AS-JAN" ... ) >>> pfo.allocations symbol MSFT AMZN AAPL Date 2011-01-03 00:00:00+00:00 0.160335 0.122434 0.717231 2012-01-03 00:00:00+00:00 0.071386 0.469564 0.459051 2013-01-02 00:00:00+00:00 0.125853 0.168480 0.705668 2014-01-02 00:00:00+00:00 0.391565 0.169205 0.439231 2015-01-02 00:00:00+00:00 0.115075 0.602844 0.282081 2016-01-04 00:00:00+00:00 0.244070 0.046547 0.709383 2017-01-03 00:00:00+00:00 0.316065 0.335000 0.348935 2018-01-02 00:00:00+00:00 0.422142 0.252154 0.325704 2019-01-02 00:00:00+00:00 0.368748 0.195147 0.436106 ``` * Specify index points manually: ```pycon >>> pfo = vbt.PortfolioOptimizer.from_allocate_func( ... close.vbt.wrapper, ... random_allocate_func, ... close.shape[1], ... index_points=[0, 30, 60] ... 
) >>> pfo.allocations symbol MSFT AMZN AAPL Date 2010-01-04 00:00:00+00:00 0.257878 0.308287 0.433835 2010-02-17 00:00:00+00:00 0.090927 0.471980 0.437094 2010-03-31 00:00:00+00:00 0.395855 0.148516 0.455629 ``` * Specify allocations manually: ```pycon >>> def manual_allocate_func(weights): ... return weights >>> pfo = vbt.PortfolioOptimizer.from_allocate_func( ... close.vbt.wrapper, ... manual_allocate_func, ... vbt.RepEval("weights[i]", context=dict(weights=[ ... [1, 0, 0], ... [0, 1, 0], ... [0, 0, 1] ... ])), ... index_points=[0, 30, 60] ... ) >>> pfo.allocations symbol MSFT AMZN AAPL Date 2010-01-04 00:00:00+00:00 1 0 0 2010-02-17 00:00:00+00:00 0 1 0 2010-03-31 00:00:00+00:00 0 0 1 ``` * Use Numba-compiled loop: ```pycon >>> @njit ... def random_allocate_func_nb(i, idx, n_cols): ... weights = np.random.uniform(0, 1, n_cols) ... return weights / weights.sum() >>> pfo = vbt.PortfolioOptimizer.from_allocate_func( ... close.vbt.wrapper, ... random_allocate_func_nb, ... close.shape[1], ... index_points=[0, 30, 60], ... jitted_loop=True ... ) >>> pfo.allocations symbol MSFT AMZN AAPL Date 2010-01-04 00:00:00+00:00 0.231925 0.351085 0.416990 2010-02-17 00:00:00+00:00 0.163050 0.070292 0.766658 2010-03-31 00:00:00+00:00 0.497465 0.500215 0.002319 ``` !!! hint There is no big reason of using the Numba-compiled loop, apart from when having to rebalance many thousands of times. Usually, using a regular Python loop and a Numba-compiled allocation function should suffice. 
""" from vectorbtpro._settings import settings params_cfg = settings["params"] if parameterizer is None: parameterizer = params_cfg["parameterizer"] if parameterizer is None: parameterizer = Parameterizer param_search_kwargs = merge_dicts(params_cfg["param_search_kwargs"], param_search_kwargs) if group_execute_kwargs is None: group_execute_kwargs = {} if execute_kwargs is None: execute_kwargs = {} if clean_index_kwargs is None: clean_index_kwargs = {} # Prepare group config names gc_names = [] gc_names_none = True n_configs = 0 if group_configs is not None: if isinstance(group_configs, dict): new_group_configs = [] for k, v in group_configs.items(): v = dict(v) v["_name"] = k new_group_configs.append(v) group_configs = new_group_configs else: group_configs = list(group_configs) for i, group_config in enumerate(group_configs): group_config = dict(group_config) if "args" in group_config: for k, arg in enumerate(group_config.pop("args")): group_config[f"args_{k}"] = arg if "kwargs" in group_config: for k, v in enumerate(group_config.pop("kwargs")): group_config[k] = v if "_name" in group_config and group_config["_name"] is not None: gc_names.append(group_config.pop("_name")) gc_names_none = False else: gc_names.append(n_configs) group_configs[i] = group_config n_configs += 1 else: group_configs = [] # Combine parameters paramable_kwargs = { "every": every, "normalize_every": normalize_every, "at_time": at_time, "start": start, "end": end, "exact_start": exact_start, "on": on, "add_delta": add_delta, "kind": kind, "indexer_method": indexer_method, "indexer_tolerance": indexer_tolerance, "skip_not_found": skip_not_found, "index_points": index_points, "rescale_to": rescale_to, **{f"args_{i}": args[i] for i in range(len(args))}, **kwargs, } param_dct = parameterizer.find_params_in_obj(paramable_kwargs, **param_search_kwargs) param_columns = None if len(param_dct) > 0: param_product, param_columns = combine_params( param_dct, random_subset=random_subset, 
clean_index_kwargs=clean_index_kwargs, name_tuple_to_str=name_tuple_to_str, ) if param_columns is None: n_param_configs = len(param_product[list(param_product.keys())[0]]) param_columns = pd.RangeIndex(stop=n_param_configs, name="param_config") product_group_configs = parameterizer.param_product_to_objs(paramable_kwargs, param_product) if len(group_configs) == 0: group_configs = product_group_configs else: new_group_configs = [] for i in range(len(product_group_configs)): for group_config in group_configs: new_group_config = merge_dicts(product_group_configs[i], group_config) new_group_configs.append(new_group_config) group_configs = new_group_configs # Build group index n_config_params = len(gc_names) if param_columns is not None: if n_config_params == 0 or (n_config_params == 1 and gc_names_none): group_index = param_columns else: group_index = combine_indexes( ( param_columns, pd.Index(gc_names, name="group_config"), ), **clean_index_kwargs, ) else: if n_config_params == 0 or (n_config_params == 1 and gc_names_none): group_index = pd.Index(["group"], name="group") else: group_index = pd.Index(gc_names, name="group_config") # Create group config from arguments if empty if len(group_configs) == 0: single_group = True group_configs.append(dict()) else: single_group = False # Resolve each group groupable_kwargs = { "allocate_func": allocate_func, **paramable_kwargs, "jitted_loop": jitted_loop, "jitted": jitted, "chunked": chunked, "template_context": template_context, "execute_kwargs": execute_kwargs, } new_group_configs = [] for group_config in group_configs: new_group_config = merge_dicts(groupable_kwargs, group_config) _args = () while True: if f"args_{len(_args)}" in new_group_config: _args += (new_group_config.pop(f"args_{len(_args)}"),) else: break new_group_config["args"] = _args new_group_configs.append(new_group_config) group_configs = new_group_configs # Generate allocations tasks = [] for group_idx, group_config in enumerate(group_configs): tasks.append( 
@classmethod
def from_allocations(
    cls: tp.Type[PortfolioOptimizerT],
    wrapper: ArrayWrapper,
    allocations: tp.ArrayLike,
    **kwargs,
) -> PortfolioOptimizerT:
    """Pick allocations from a (flexible) array.

    Delegates to `PortfolioOptimizer.from_allocate_func`.

    How `allocations` is interpreted:

    * DataFrame: its index is used as labels (`on=allocations.index`, `kind="labels"`
      are merged with `**kwargs`) and its values become the allocation array.
    * Series or dict: used as a single allocation without index, which by default
      gets assigned to each index.
    * Anything else that is not a NumPy array: a conversion to a NumPy array is attempted
      and kept only if the resulting data type is not `object`.

    If `allocations` ends up being a NumPy array, uses
    `vectorbtpro.portfolio.pfopt.nb.pick_idx_allocate_func_nb` and a Numba-compiled loop.
    Otherwise, uses a regular Python function to pick each allocation
    (which can be a dict, Series, etc.).

    Selection of elements is done in a flexible manner, meaning a single element will be
    applied to all rows, while one-dimensional arrays will be also applied to all rows
    but also broadcast across columns (as opposed to rows)."""
    # Normalize pandas/dict inputs first
    if isinstance(allocations, pd.Series):
        allocations = allocations.to_dict()
    if isinstance(allocations, dict):
        allocations = pd.DataFrame([allocations], columns=wrapper.columns).values
    if isinstance(allocations, pd.DataFrame):
        kwargs = merge_dicts(dict(on=allocations.index, kind="labels"), kwargs)
        allocations = allocations.values

    # Best-effort conversion of any other input; failures deliberately leave it untouched
    if not isinstance(allocations, np.ndarray):
        with WarningsFiltered():
            try:
                converted = np.asarray(allocations)
                if converted.dtype != object:
                    allocations = converted
            except Exception:
                pass

    if not isinstance(allocations, np.ndarray):
        # Generic objects: pick the allocation in a regular Python loop
        def _pick_allocate_func(index_points, i):
            if not checks.is_sequence(allocations):
                return allocations
            n_allocations = len(allocations)
            if n_allocations == 1:
                return allocations[0]
            if len(index_points) != n_allocations:
                raise ValueError(f"Allocation array must have {len(index_points)} rows")
            return allocations[i]

        return cls.from_allocate_func(
            wrapper,
            _pick_allocate_func,
            Rep("index_points"),
            Rep("i"),
            **kwargs,
        )

    # NumPy arrays: broadcast once the index points are known, then pick inside a jitted loop
    def _resolve_allocations(index_points):
        n_rows, n_cols = len(index_points), len(wrapper.columns)
        return broadcast_array_to(allocations, (n_rows, n_cols), expand_axis=0)

    return cls.from_allocate_func(
        wrapper,
        nb.pick_idx_allocate_func_nb,
        RepFunc(_resolve_allocations),
        jitted_loop=True,
        **kwargs,
    )
@classmethod
def from_initial(
    cls: tp.Type[PortfolioOptimizerT],
    wrapper: ArrayWrapper,
    allocations: tp.ArrayLike,
    **kwargs,
) -> PortfolioOptimizerT:
    """Allocate a single time, at the first index.

    Thin shortcut for `PortfolioOptimizer.from_allocations` with `on=0`;
    all other keyword arguments are forwarded unchanged."""
    return cls.from_allocations(wrapper, allocations, **kwargs, on=0)
@classmethod
def from_filled_allocations(
    cls: tp.Type[PortfolioOptimizerT],
    allocations: tp.AnyArray2d,
    valid_only: bool = True,
    nonzero_only: bool = True,
    unique_only: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """Pick allocations from an already filled array.

    Delegates to `PortfolioOptimizer.from_allocate_func` with
    `vectorbtpro.portfolio.pfopt.nb.pick_point_allocate_func_nb` and a Numba-compiled loop.

    Allocation points are extracted using `vectorbtpro.portfolio.pfopt.nb.get_alloc_points_nb`
    (controlled by `valid_only`, `nonzero_only`, and `unique_only`) and passed as `on`,
    merged with `**kwargs`.

    If `wrapper` is not provided, `allocations` must be a DataFrame to build one from.
    Raises `TypeError` otherwise, and `ValueError` if the 2-dim allocation array
    does not have the same shape as the wrapper."""
    if wrapper is None:
        if not checks.is_frame(allocations):
            raise TypeError("Wrapper is required if allocations is not a DataFrame")
        wrapper = ArrayWrapper.from_obj(allocations)

    filled = to_2d_array(allocations, expand_axis=0)
    if filled.shape != wrapper.shape_2d:
        raise ValueError("Allocation array must have the same shape as wrapper")

    alloc_points = nb.get_alloc_points_nb(
        filled,
        valid_only=valid_only,
        nonzero_only=nonzero_only,
        unique_only=unique_only,
    )
    merged_kwargs = merge_dicts(dict(on=alloc_points), kwargs)
    return cls.from_allocate_func(
        wrapper,
        nb.pick_point_allocate_func_nb,
        filled,
        jitted_loop=True,
        **merged_kwargs,
    )

@classmethod
def from_uniform(cls: tp.Type[PortfolioOptimizerT], wrapper: ArrayWrapper, **kwargs) -> PortfolioOptimizerT:
    """Generate uniform allocations.

    Delegates to `PortfolioOptimizer.from_allocate_func` with an allocation function
    that assigns the weight `1 / n_cols` to each of the wrapper's columns."""
    n_cols = wrapper.shape_2d[1]

    def _uniform_allocate_func():
        return np.full(n_cols, 1 / n_cols)

    return cls.from_allocate_func(wrapper, _uniform_allocate_func, **kwargs)

@classmethod
def from_random(
    cls: tp.Type[PortfolioOptimizerT],
    wrapper: ArrayWrapper,
    direction: tp.Union[str, int] = "longonly",
    n: tp.Optional[int] = None,
    seed: tp.Optional[int] = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """Generate random allocations.

    Delegates to `PortfolioOptimizer.from_allocate_func` with
    `vectorbtpro.portfolio.pfopt.nb.random_allocate_func_nb` and a Numba-compiled loop.

    A string `direction` is mapped to a field of `Direction`. If `seed` is not None,
    it is set with `set_seed_nb` before generating."""
    resolved_direction = direction
    if isinstance(resolved_direction, str):
        resolved_direction = map_enum_fields(resolved_direction, Direction)
    if seed is not None:
        set_seed_nb(seed)
    n_cols = wrapper.shape_2d[1]
    return cls.from_allocate_func(
        wrapper,
        nb.random_allocate_func_nb,
        n_cols,
        resolved_direction,
        n,
        jitted_loop=True,
        **kwargs,
    )
@classmethod
def from_universal_algo(
    cls: tp.Type[PortfolioOptimizerT],
    algo: tp.Union[str, tp.Type[AlgoT], AlgoT, AlgoResultT],
    S: tp.Optional[tp.AnyArray2d] = None,
    n_jobs: int = 1,
    log_progress: bool = False,
    valid_only: bool = True,
    nonzero_only: bool = True,
    unique_only: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """Generate allocations using [Universal Portfolios](https://github.com/Marigold/universal-portfolios).

    `S` can be any price, while `algo` must be either an attribute of the package, subclass of
    `universal.algo.Algo`, instance of `universal.algo.Algo`, or instance of `universal.result.AlgoResult`.

    The algorithm is resolved per group inside a `pre_group_func`:

    * A string is looked up in `universal.algos`.
    * A subclass of `universal.algo.Algo` is instantiated with those keys of the group config
      that are not argument names of `PortfolioOptimizer.from_allocate_func`.
    * An instance of `universal.algo.Algo` is run on `S` with `n_jobs` and `log_progress`.
    * From the resulting `universal.result.AlgoResult`, the weights of the wrapper's columns are taken.

    Extracts allocation points using `vectorbtpro.portfolio.pfopt.nb.get_alloc_points_nb`
    unless `on` was passed via `**kwargs`.

    Raises:
        TypeError: If `wrapper` is None and `S` is not a DataFrame, or if `algo`
            cannot be resolved into an `AlgoResult`.
        ValueError: If the algorithm must be run but `S` is None.
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("universal")
    from universal.algo import Algo
    from universal.result import AlgoResult

    if wrapper is None:
        if S is None or not checks.is_frame(S):
            # The value being checked is `S`; this method has no `allocations` argument
            raise TypeError("Wrapper is required if S is not a DataFrame")
        wrapper = ArrayWrapper.from_obj(S)

    def _pre_group_func(group_config, _algo=algo):
        # Positional arguments are replaced by the weights below
        _ = group_config.pop("args", ())
        if isinstance(_algo, str):
            import universal.algos

            _algo = getattr(universal.algos, _algo)
        if isinstance(_algo, type) and issubclass(_algo, Algo):
            # Everything that isn't an argument of from_allocate_func goes to the algo's constructor
            reserved_arg_names = get_func_arg_names(cls.from_allocate_func)
            algo_keys = set(group_config.keys()).difference(reserved_arg_names)
            algo_kwargs = {}
            for k in algo_keys:
                algo_kwargs[k] = group_config.pop(k)
            _algo = _algo(**algo_kwargs)
        if isinstance(_algo, Algo):
            if S is None:
                raise ValueError("S is required")
            _algo = _algo.run(S, n_jobs=n_jobs, log_progress=log_progress)
        if isinstance(_algo, AlgoResult):
            weights = _algo.weights[wrapper.columns].values
        else:
            raise TypeError(f"Algo {_algo} not supported")
        if "on" not in kwargs:
            group_config["on"] = nb.get_alloc_points_nb(
                weights,
                valid_only=valid_only,
                nonzero_only=nonzero_only,
                unique_only=unique_only,
            )
        group_config["args"] = (weights,)

    return cls.from_allocate_func(
        wrapper,
        nb.pick_point_allocate_func_nb,
        jitted_loop=True,
        pre_group_func=_pre_group_func,
        **kwargs,
    )
@classmethod
def run_optimization_group(
    cls,
    wrapper: ArrayWrapper,
    group_configs: tp.List[dict],
    group_index: tp.Index,
    group_idx: int,
    pre_group_func: tp.Optional[tp.Callable] = None,
    silence_warnings: bool = False,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
    """Run an optimization group.

    Takes a copy of the group config under `group_idx`, optionally passes it to `pre_group_func`
    for in-place processing, and pops all reserved settings from it. Whatever remains
    in the config is used as keyword arguments of the optimization function.

    Index ranges are either generated with `wrapper.get_index_ranges` (if `index_ranges` is None)
    or taken from the template-substituted `index_ranges`. The optimization function is then
    called either once through `nb.optimize_meta_nb` (if `jitted_loop` is True) or once per range
    as a task passed to `execute`. In the latter case, arguments of type `Takeable` that meet
    `eval_id` are replaced by their selection over the current range.

    Returns the output of `nb.prepare_alloc_ranges_nb`."""
    # Copy the config such that popping below doesn't alter the entry in `group_configs`
    group_config = dict(group_configs[group_idx])
    if pre_group_func is not None:
        pre_group_func(group_config)

    # Pop reserved settings; each of them must be present in the config
    optimize_func = group_config.pop("optimize_func")
    every = group_config.pop("every")
    normalize_every = group_config.pop("normalize_every")
    split_every = group_config.pop("split_every")
    start_time = group_config.pop("start_time")
    end_time = group_config.pop("end_time")
    lookback_period = group_config.pop("lookback_period")
    start = group_config.pop("start")
    end = group_config.pop("end")
    exact_start = group_config.pop("exact_start")
    fixed_start = group_config.pop("fixed_start")
    closed_start = group_config.pop("closed_start")
    closed_end = group_config.pop("closed_end")
    add_start_delta = group_config.pop("add_start_delta")
    add_end_delta = group_config.pop("add_end_delta")
    kind = group_config.pop("kind")
    skip_not_found = group_config.pop("skip_not_found")
    index_ranges = group_config.pop("index_ranges")
    index_loc = group_config.pop("index_loc")
    rescale_to = group_config.pop("rescale_to")
    alloc_wait = group_config.pop("alloc_wait")
    splitter_cls = group_config.pop("splitter_cls")
    eval_id = group_config.pop("eval_id")
    jitted_loop = group_config.pop("jitted_loop")
    jitted = group_config.pop("jitted")
    chunked = group_config.pop("chunked")
    template_context = group_config.pop("template_context")
    execute_kwargs = group_config.pop("execute_kwargs")
    args = group_config.pop("args")
    # The remaining entries are keyword arguments of the optimization function
    kwargs = group_config
    if splitter_cls is None:
        splitter_cls = Splitter

    # Expose all settings to templates; the user-provided context is merged last
    template_context = merge_dicts(
        dict(
            group_configs=group_configs,
            group_index=group_index,
            group_idx=group_idx,
            wrapper=wrapper,
            optimize_func=optimize_func,
            every=every,
            normalize_every=normalize_every,
            split_every=split_every,
            start_time=start_time,
            end_time=end_time,
            lookback_period=lookback_period,
            start=start,
            end=end,
            exact_start=exact_start,
            fixed_start=fixed_start,
            closed_start=closed_start,
            closed_end=closed_end,
            add_start_delta=add_start_delta,
            add_end_delta=add_end_delta,
            kind=kind,
            skip_not_found=skip_not_found,
            index_ranges=index_ranges,
            index_loc=index_loc,
            rescale_to=rescale_to,
            alloc_wait=alloc_wait,
            splitter_cls=splitter_cls,
            eval_id=eval_id,
            jitted_loop=jitted_loop,
            jitted=jitted,
            chunked=chunked,
            args=args,
            kwargs=kwargs,
            execute_kwargs=execute_kwargs,
        ),
        template_context,
    )

    if index_ranges is None:
        # Generate index ranges from the (template-substituted) range arguments
        get_index_ranges_defaults = substitute_templates(
            dict(
                every=every,
                normalize_every=normalize_every,
                split_every=split_every,
                start_time=start_time,
                end_time=end_time,
                lookback_period=lookback_period,
                start=start,
                end=end,
                exact_start=exact_start,
                fixed_start=fixed_start,
                closed_start=closed_start,
                closed_end=closed_end,
                add_start_delta=add_start_delta,
                add_end_delta=add_end_delta,
                kind=kind,
                skip_not_found=skip_not_found,
                jitted=jitted,
            ),
            template_context,
            eval_id="get_index_ranges_defaults",
            strict=True,
        )
        index_ranges = wrapper.get_index_ranges(**get_index_ranges_defaults)
        template_context = merge_dicts(
            template_context,
            get_index_ranges_defaults,
            dict(index_ranges=index_ranges),
        )
    else:
        # Use the provided index ranges and bring them to a tuple of (start, end) arrays
        index_ranges = substitute_templates(
            index_ranges,
            template_context,
            eval_id="index_ranges",
            strict=True,
        )
        if isinstance(index_ranges, np.ndarray):
            index_ranges = (index_ranges[:, 0], index_ranges[:, 1])
        elif not isinstance(index_ranges[0], np.ndarray) and not isinstance(index_ranges[1], np.ndarray):
            # Neither of the first two elements is an array: treat the input as rows of (start, end)
            index_ranges = to_2d_array(index_ranges, expand_axis=0)
            index_ranges = (index_ranges[:, 0], index_ranges[:, 1])
        template_context = merge_dicts(template_context, dict(index_ranges=index_ranges))
    if index_loc is not None:
        index_loc = substitute_templates(
            index_loc,
            template_context,
            eval_id="index_loc",
            strict=True,
        )
        index_loc = to_1d_array(index_loc)
        template_context = merge_dicts(template_context, dict(index_loc=index_loc))

    if jitted_loop:
        # Templates are substituted only once since the loop runs inside the meta function
        optimize_func = substitute_templates(
            optimize_func,
            template_context,
            eval_id="optimize_func",
            strict=True,
        )
        args = substitute_templates(args, template_context, eval_id="args")
        kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
        func = jit_reg.resolve_option(nb.optimize_meta_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        _allocations = func(
            len(wrapper.columns),
            index_ranges[0],
            index_ranges[1],
            optimize_func,
            *args,
            **kwargs,
        )
    else:
        # Build one task per index range; templates are substituted for each range separately
        tasks = []
        keys = []
        for i in range(len(index_ranges[0])):
            # The start of the slice is clipped at zero
            index_slice = slice(max(0, index_ranges[0][i]), index_ranges[1][i])
            _template_context = merge_dicts(
                dict(
                    i=i,
                    index_slice=index_slice,
                    index_start=index_ranges[0][i],
                    index_end=index_ranges[1][i],
                ),
                template_context,
            )
            _optimize_func = substitute_templates(
                optimize_func,
                _template_context,
                eval_id="optimize_func",
                strict=True,
            )
            _args = substitute_templates(args, _template_context, eval_id="args")
            _kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs")
            if has_annotatables(_optimize_func):
                # Let the splitter class turn annotated arguments into takeables
                ann_args = annotate_args(
                    _optimize_func,
                    _args,
                    _kwargs,
                    attach_annotations=True,
                )
                flat_ann_args = flatten_ann_args(ann_args)
                flat_ann_args = splitter_cls.parse_and_inject_takeables(flat_ann_args, eval_id=eval_id)
                ann_args = unflatten_ann_args(flat_ann_args)
                _args, _kwargs = ann_args_to_args(ann_args)
            # Replace each takeable that meets the evaluation id by its selection over the current range
            __args = ()
            for v in _args:
                if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
                    v = splitter_cls.take_range_from_takeable(
                        v,
                        index_slice,
                        template_context=_template_context,
                        silence_warnings=silence_warnings,
                        index=wrapper.index,
                        freq=wrapper.freq,
                    )
                __args += (v,)
            __kwargs = {}
            for k, v in _kwargs.items():
                if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
                    v = splitter_cls.take_range_from_takeable(
                        v,
                        index_slice,
                        template_context=_template_context,
                        silence_warnings=silence_warnings,
                        index=wrapper.index,
                        freq=wrapper.freq,
                    )
                __kwargs[k] = v

            tasks.append(Task(_optimize_func, *__args, **__kwargs))
            # The task key shows the first and the last (inclusive) index label of the range
            if isinstance(wrapper.index, pd.DatetimeIndex):
                keys.append(
                    "{} → {}".format(
                        dt.readable_datetime(wrapper.index[index_ranges[0][i]], freq=wrapper.freq),
                        dt.readable_datetime(wrapper.index[index_ranges[1][i] - 1], freq=wrapper.freq),
                    )
                )
            else:
                keys.append(
                    "{} → {}".format(
                        str(wrapper.index[index_ranges[0][i]]),
                        str(wrapper.index[index_ranges[1][i] - 1]),
                    )
                )

        results = execute(tasks, keys=keys, **execute_kwargs)
        # Gather the results into a 2-dim array with one row per range
        _allocations = pd.DataFrame(results, columns=wrapper.columns)
        if isinstance(_allocations.columns, pd.RangeIndex):
            _allocations = _allocations.values
        else:
            _allocations = _allocations[list(wrapper.columns)].values

    if rescale_to is not None:
        _allocations = nb.rescale_allocations_nb(_allocations, rescale_to)
    if index_loc is None:
        # By default, allocate `alloc_wait` rows after the last row of each range
        alloc_wait = substitute_templates(
            alloc_wait,
            template_context,
            eval_id="alloc_wait",
            strict=True,
        )
        alloc_idx = index_ranges[1] - 1 + alloc_wait
    else:
        alloc_idx = index_loc
    # Ranges whose allocation index lies at or beyond the end of the index are marked as open
    status = np.where(
        alloc_idx >= len(wrapper.index),
        RangeStatus.Open,
        RangeStatus.Closed,
    )
    return nb.prepare_alloc_ranges_nb(
        index_ranges[0],
        index_ranges[1],
        alloc_idx,
        status,
        _allocations,
        group_idx,
    )
@classmethod
def from_optimize_func(
    cls: tp.Type[PortfolioOptimizerT],
    wrapper: ArrayWrapper,
    optimize_func: tp.Callable,
    *args,
    every: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["every"],
    normalize_every: tp.Union[bool, Param] = range_idxr_defaults["normalize_every"],
    split_every: tp.Union[bool, Param] = range_idxr_defaults["split_every"],
    start_time: tp.Union[None, tp.TimeLike, Param] = range_idxr_defaults["start_time"],
    end_time: tp.Union[None, tp.TimeLike, Param] = range_idxr_defaults["end_time"],
    lookback_period: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["lookback_period"],
    start: tp.Union[None, int, tp.DatetimeLike, tp.IndexLike, Param] = range_idxr_defaults["start"],
    end: tp.Union[None, int, tp.DatetimeLike, tp.IndexLike, Param] = range_idxr_defaults["end"],
    exact_start: tp.Union[bool, Param] = range_idxr_defaults["exact_start"],
    fixed_start: tp.Union[bool, Param] = range_idxr_defaults["fixed_start"],
    closed_start: tp.Union[bool, Param] = range_idxr_defaults["closed_start"],
    closed_end: tp.Union[bool, Param] = range_idxr_defaults["closed_end"],
    add_start_delta: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["add_start_delta"],
    add_end_delta: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["add_end_delta"],
    kind: tp.Union[None, str, Param] = range_idxr_defaults["kind"],
    skip_not_found: tp.Union[bool, Param] = range_idxr_defaults["skip_not_found"],
    index_ranges: tp.Union[None, tp.MaybeSequence[tp.MaybeSequence[int]], Param] = None,
    index_loc: tp.Union[None, tp.MaybeSequence[int], Param] = None,
    rescale_to: tp.Union[None, tp.Tuple[float, float], Param] = None,
    alloc_wait: tp.Union[int, Param] = 1,
    parameterizer: tp.Optional[tp.MaybeType[Parameterizer]] = None,
    param_search_kwargs: tp.KwargsLike = None,
    name_tuple_to_str: tp.Union[None, bool, tp.Callable] = None,
    group_configs: tp.Union[None, tp.Dict[tp.Hashable, tp.Kwargs], tp.Sequence[tp.Kwargs]] = None,
    pre_group_func: tp.Optional[tp.Callable] = None,
    splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
    eval_id: tp.Optional[tp.Hashable] = None,
    jitted_loop: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    template_context: tp.KwargsLike = None,
    group_execute_kwargs: tp.KwargsLike = None,
    execute_kwargs: tp.KwargsLike = None,
    random_subset: tp.Optional[int] = None,
    clean_index_kwargs: tp.KwargsLike = None,
    wrapper_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """Generate allocations from an optimization function.

    Generates date ranges, performs optimization on the subset of data that belongs to each date range,
    and allocates at the end of each range.

    This is a parameterized method that allows testing multiple combinations on most arguments.
    First, it checks whether any of the arguments is wrapped with `vectorbtpro.utils.params.Param`
    and combines their values. It then combines them over `group_configs`, if provided.
    Before execution, it additionally processes the group config using `pre_group_func`.

    `group_configs` can be either a sequence of dicts or a dict of dicts keyed by the group name.
    Each group config may define `args` (a sequence that replaces positional arguments),
    `kwargs` (a mapping that is merged into the group config), and `_name` (the group name).
    Group configs passed by the user are copied and never modified in place.

    It then resolves the date ranges, either using the ready-to-use `index_ranges` or
    by passing all the arguments ranging from `every` to `jitted` to
    `vectorbtpro.base.wrapping.ArrayWrapper.get_index_ranges`. The optimization function
    `optimize_func` is then called on each date range by first substituting any templates
    found in `*args` and `**kwargs`. To forward any reserved arguments such as `jitted` to the
    optimization function, specify their names in `forward_args` and `forward_kwargs`.

    !!! note
        Make sure to use vectorbt's own templates to select the current date range
        (available as `index_slice` in the context mapping) from each array.

    If `jitted_loop` is True, see `vectorbtpro.portfolio.pfopt.nb.optimize_meta_nb`.
    Otherwise, must take template-substituted `*args` and `**kwargs`, and return an array or
    dictionary with asset allocations (also empty).

    Templates can use the following variables:

    * `i`: Optimization step
    * `index_start`: Optimization start index (including)
    * `index_end`: Optimization end index (excluding)
    * `index_slice`: `slice(index_start, index_end)`

    !!! note
        When `jitted_loop` is True and in case of multiple groups, use templates
        to substitute by the current group index (available as `group_idx` in the context mapping).

    All allocations of all groups are stacked into one big 2-dim array where columns are assets
    and rows are allocations. Furthermore, date ranges are used to fill a record array of type
    `vectorbtpro.portfolio.pfopt.records.AllocRanges` that acts as an indexer for allocations.
    For example, the field `col` stores the group index corresponding to each allocation. Since
    this record array does not hold any information on assets themselves, it has its own wrapper
    that holds groups instead of columns, while the wrapper of the `PortfolioOptimizer` instance
    contains regular columns grouped by groups.

    Usage:
        * Allocate once:

        ```pycon
        >>> from vectorbtpro import *

        >>> data = vbt.YFData.pull(
        ...     ["MSFT", "AMZN", "AAPL"],
        ...     start="2010-01-01",
        ...     end="2020-01-01"
        ... )
        >>> close = data.get("Close")

        >>> def optimize_func(df):
        ...     sharpe = df.mean() / df.std()
        ...     return sharpe / sharpe.sum()

        >>> df_arg = vbt.RepEval("close.iloc[index_slice]", context=dict(close=close))
        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func,
        ...     df_arg,
        ...     end="2015-01-01"
        ... )
        >>> pfo.allocations
        symbol                                     MSFT      AMZN      AAPL
        alloc_group Date
        group       2015-01-02 00:00:00+00:00  0.402459  0.309351  0.288191
        ```

        * Allocate every first date of the year:

        ```pycon
        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func,
        ...     df_arg,
        ...     every="AS-JAN"
        ... )
        >>> pfo.allocations
        symbol                                     MSFT      AMZN      AAPL
        alloc_group Date
        group       2011-01-03 00:00:00+00:00  0.480693  0.257317  0.261990
                    2012-01-03 00:00:00+00:00  0.489893  0.215381  0.294727
                    2013-01-02 00:00:00+00:00  0.540165  0.228755  0.231080
                    2014-01-02 00:00:00+00:00  0.339649  0.273996  0.386354
                    2015-01-02 00:00:00+00:00  0.350406  0.418638  0.230956
                    2016-01-04 00:00:00+00:00  0.332212  0.141090  0.526698
                    2017-01-03 00:00:00+00:00  0.390852  0.225379  0.383769
                    2018-01-02 00:00:00+00:00  0.337711  0.317683  0.344606
                    2019-01-02 00:00:00+00:00  0.411852  0.282680  0.305468
        ```

        * Specify index ranges manually:

        ```pycon
        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func,
        ...     df_arg,
        ...     index_ranges=[
        ...         (0, 30),
        ...         (30, 60),
        ...         (60, 90)
        ...     ]
        ... )
        >>> pfo.allocations
        symbol                                     MSFT      AMZN      AAPL
        alloc_group Date
        group       2010-02-16 00:00:00+00:00  0.340641  0.285897  0.373462
                    2010-03-30 00:00:00+00:00  0.596392  0.206317  0.197291
                    2010-05-12 00:00:00+00:00  0.437481  0.283160  0.279358
        ```

        * Test multiple combinations of one argument:

        ```pycon
        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func,
        ...     df_arg,
        ...     every="AS-JAN",
        ...     start="2015-01-01",
        ...     lookback_period=vbt.Param(["3MS", "6MS"])
        ... )
        >>> pfo.allocations
        symbol                                         MSFT      AMZN      AAPL
        lookback_period Date
        3MS             2016-01-04 00:00:00+00:00  0.282725  0.234970  0.482305
                        2017-01-03 00:00:00+00:00  0.318100  0.269355  0.412545
                        2018-01-02 00:00:00+00:00  0.387499  0.236432  0.376068
                        2019-01-02 00:00:00+00:00  0.575464  0.254808  0.169728
        6MS             2016-01-04 00:00:00+00:00  0.265035  0.198619  0.536346
                        2017-01-03 00:00:00+00:00  0.314144  0.409020  0.276836
                        2018-01-02 00:00:00+00:00  0.322741  0.282639  0.394621
                        2019-01-02 00:00:00+00:00  0.565691  0.234760  0.199549
        ```

        * Test multiple cross-argument combinations:

        ```pycon
        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func,
        ...     df_arg,
        ...     every="AS-JAN",
        ...     group_configs=[
        ...         dict(start="2015-01-01"),
        ...         dict(start="2019-06-01", every="MS"),
        ...         dict(end="2014-01-01")
        ...     ]
        ... )
        >>> pfo.allocations
        symbol                                      MSFT      AMZN      AAPL
        group_config Date
        0            2016-01-04 00:00:00+00:00  0.332212  0.141090  0.526698
                     2017-01-03 00:00:00+00:00  0.390852  0.225379  0.383769
                     2018-01-02 00:00:00+00:00  0.337711  0.317683  0.344606
                     2019-01-02 00:00:00+00:00  0.411852  0.282680  0.305468
        1            2019-07-01 00:00:00+00:00  0.351461  0.327334  0.321205
                     2019-08-01 00:00:00+00:00  0.418411  0.249799  0.331790
                     2019-09-03 00:00:00+00:00  0.400439  0.374044  0.225517
                     2019-10-01 00:00:00+00:00  0.509387  0.250497  0.240117
                     2019-11-01 00:00:00+00:00  0.349983  0.469181  0.180835
                     2019-12-02 00:00:00+00:00  0.260437  0.380563  0.359000
        2            2012-01-03 00:00:00+00:00  0.489892  0.215381  0.294727
                     2013-01-02 00:00:00+00:00  0.540165  0.228755  0.231080
                     2014-01-02 00:00:00+00:00  0.339649  0.273997  0.386354
        ```

        * Use Numba-compiled loop:

        ```pycon
        >>> @njit
        ... def optimize_func_nb(i, from_idx, to_idx, close):
        ...     mean = vbt.nb.nanmean_nb(close[from_idx:to_idx])
        ...     std = vbt.nb.nanstd_nb(close[from_idx:to_idx])
        ...     sharpe = mean / std
        ...     return sharpe / np.sum(sharpe)

        >>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
        ...     close.vbt.wrapper,
        ...     optimize_func_nb,
        ...     np.asarray(close),
        ...     index_ranges=[
        ...         (0, 30),
        ...         (30, 60),
        ...         (60, 90)
        ...     ],
        ...     jitted_loop=True
        ... )
        >>> pfo.allocations
        symbol                         MSFT      AMZN      AAPL
        Date
        2010-02-17 00:00:00+00:00  0.336384  0.289598  0.374017
        2010-03-31 00:00:00+00:00  0.599417  0.207158  0.193425
        2010-05-13 00:00:00+00:00  0.434084  0.281246  0.284670
        ```

    !!! hint
        There is no big reason of using the Numba-compiled loop, apart from when having
        to rebalance many thousands of times. Usually, using a regular Python loop
        and a Numba-compiled optimization function should suffice.
    """
    from vectorbtpro._settings import settings

    params_cfg = settings["params"]

    # Fall back to global settings and then to the default parameterizer
    if parameterizer is None:
        parameterizer = params_cfg["parameterizer"]
    if parameterizer is None:
        parameterizer = Parameterizer
    param_search_kwargs = merge_dicts(params_cfg["param_search_kwargs"], param_search_kwargs)
    if group_execute_kwargs is None:
        group_execute_kwargs = {}
    if execute_kwargs is None:
        execute_kwargs = {}
    if clean_index_kwargs is None:
        clean_index_kwargs = {}

    # Prepare group config names
    gc_names = []
    gc_names_none = True
    n_configs = 0
    if group_configs is not None:
        # A dict must be converted BEFORE iterating: its keys become the names of the configs.
        # Calling list() on the dict first would keep only the keys and lose the configs.
        if isinstance(group_configs, dict):
            named_group_configs = []
            for name, named_config in group_configs.items():
                named_config = dict(named_config)
                named_config["_name"] = name
                named_group_configs.append(named_config)
            group_configs = named_group_configs
        else:
            group_configs = list(group_configs)
        for i, group_config in enumerate(group_configs):
            # Work on a copy such that popping doesn't modify the dict passed by the user
            group_config = dict(group_config)
            if "args" in group_config:
                # Positional arguments are stored under "args_0", "args_1", ...
                for k, arg in enumerate(group_config.pop("args")):
                    group_config[f"args_{k}"] = arg
            if "kwargs" in group_config:
                # Merge the keyword mapping by name; enumerating it would store
                # "position -> key name" pairs instead of "key name -> value"
                group_config.update(group_config.pop("kwargs"))
            if "_name" in group_config and group_config["_name"] is not None:
                gc_names.append(group_config.pop("_name"))
                gc_names_none = False
            else:
                gc_names.append(n_configs)
            group_configs[i] = group_config
            n_configs += 1
    else:
        group_configs = []

    # Combine parameters
    paramable_kwargs = {
        "every": every,
        "normalize_every": normalize_every,
        "split_every": split_every,
        "start_time": start_time,
        "end_time": end_time,
        "lookback_period": lookback_period,
        "start": start,
        "end": end,
        "exact_start": exact_start,
        "fixed_start": fixed_start,
        "closed_start": closed_start,
        "closed_end": closed_end,
        "add_start_delta": add_start_delta,
        "add_end_delta": add_end_delta,
        "kind": kind,
        "skip_not_found": skip_not_found,
        "index_ranges": index_ranges,
        "index_loc": index_loc,
        "rescale_to": rescale_to,
        "alloc_wait": alloc_wait,
        **{f"args_{i}": arg for i, arg in enumerate(args)},
        **kwargs,
    }
    param_dct = parameterizer.find_params_in_obj(paramable_kwargs, **param_search_kwargs)
    param_columns = None
    if len(param_dct) > 0:
        param_product, param_columns = combine_params(
            param_dct,
            random_subset=random_subset,
            clean_index_kwargs=clean_index_kwargs,
            name_tuple_to_str=name_tuple_to_str,
        )
        if param_columns is None:
            # No index was built: enumerate the parameter combinations
            n_param_configs = len(param_product[list(param_product.keys())[0]])
            param_columns = pd.RangeIndex(stop=n_param_configs, name="param_config")
        product_group_configs = parameterizer.param_product_to_objs(paramable_kwargs, param_product)
        if len(group_configs) == 0:
            group_configs = product_group_configs
        else:
            # Cross each parameter combination with each group config
            new_group_configs = []
            for product_group_config in product_group_configs:
                for group_config in group_configs:
                    new_group_config = merge_dicts(product_group_config, group_config)
                    new_group_configs.append(new_group_config)
            group_configs = new_group_configs

    # Build group index
    n_config_params = len(gc_names)
    if param_columns is not None:
        if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
            group_index = param_columns
        else:
            group_index = combine_indexes(
                (
                    param_columns,
                    pd.Index(gc_names, name="group_config"),
                ),
                **clean_index_kwargs,
            )
    else:
        if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
            group_index = pd.Index(["group"], name="group")
        else:
            group_index = pd.Index(gc_names, name="group_config")

    # Create group config from arguments if empty
    if len(group_configs) == 0:
        single_group = True
        group_configs.append(dict())
    else:
        single_group = False

    # Resolve each group
    groupable_kwargs = {
        "optimize_func": optimize_func,
        **paramable_kwargs,
        "splitter_cls": splitter_cls,
        "eval_id": eval_id,
        "jitted_loop": jitted_loop,
        "jitted": jitted,
        "chunked": chunked,
        "template_context": template_context,
        "execute_kwargs": execute_kwargs,
    }
    new_group_configs = []
    for group_config in group_configs:
        new_group_config = merge_dicts(groupable_kwargs, group_config)
        # Collect consecutive "args_0", "args_1", ... back into a tuple of positional arguments
        _args = ()
        while True:
            if f"args_{len(_args)}" in new_group_config:
                _args += (new_group_config.pop(f"args_{len(_args)}"),)
            else:
                break
        new_group_config["args"] = _args
        new_group_configs.append(new_group_config)
    group_configs = new_group_configs

    # Generate allocations
    tasks = []
    for group_idx, group_config in enumerate(group_configs):
        tasks.append(
            Task(
                cls.run_optimization_group,
                wrapper=wrapper,
                group_configs=group_configs,
                group_index=group_index,
                group_idx=group_idx,
                pre_group_func=pre_group_func,
            )
        )
    group_execute_kwargs = merge_dicts(dict(show_progress=False if single_group else None), group_execute_kwargs)
    results = execute(tasks, keys=group_index, **group_execute_kwargs)
    alloc_ranges, allocations = zip(*results)

    # Build column hierarchy
    new_columns = combine_indexes((group_index, wrapper.columns), **clean_index_kwargs)

    # Create instance
    wrapper_kwargs = merge_dicts(
        dict(
            index=wrapper.index,
            columns=new_columns,
            ndim=2,
            freq=wrapper.freq,
            column_only_select=False,
            range_only_select=True,
            group_select=True,
            grouped_ndim=1 if single_group else 2,
            group_by=group_index.names if group_index.nlevels > 1 else group_index.name,
            allow_enable=False,
            allow_disable=True,
            allow_modify=False,
        ),
        wrapper_kwargs,
    )
    new_wrapper = ArrayWrapper(**wrapper_kwargs)
    # Records get their own wrapper that holds groups instead of columns
    alloc_ranges = AllocRanges(
        ArrayWrapper(
            index=wrapper.index,
            columns=new_wrapper.get_columns(),
            ndim=new_wrapper.get_ndim(),
            freq=wrapper.freq,
            column_only_select=False,
            range_only_select=True,
        ),
        np.concatenate(alloc_ranges),
    )
    allocations = row_stack_arrays(allocations)
    return cls(new_wrapper, alloc_ranges, allocations)
"optimize_func": optimize_func, **paramable_kwargs, "splitter_cls": splitter_cls, "eval_id": eval_id, "jitted_loop": jitted_loop, "jitted": jitted, "chunked": chunked, "template_context": template_context, "execute_kwargs": execute_kwargs, } new_group_configs = [] for group_config in group_configs: new_group_config = merge_dicts(groupable_kwargs, group_config) _args = () while True: if f"args_{len(_args)}" in new_group_config: _args += (new_group_config.pop(f"args_{len(_args)}"),) else: break new_group_config["args"] = _args new_group_configs.append(new_group_config) group_configs = new_group_configs # Generate allocations tasks = [] for group_idx, group_config in enumerate(group_configs): tasks.append( Task( cls.run_optimization_group, wrapper=wrapper, group_configs=group_configs, group_index=group_index, group_idx=group_idx, pre_group_func=pre_group_func, ) ) group_execute_kwargs = merge_dicts(dict(show_progress=False if single_group else None), group_execute_kwargs) results = execute(tasks, keys=group_index, **group_execute_kwargs) alloc_ranges, allocations = zip(*results) # Build column hierarchy new_columns = combine_indexes((group_index, wrapper.columns), **clean_index_kwargs) # Create instance wrapper_kwargs = merge_dicts( dict( index=wrapper.index, columns=new_columns, ndim=2, freq=wrapper.freq, column_only_select=False, range_only_select=True, group_select=True, grouped_ndim=1 if single_group else 2, group_by=group_index.names if group_index.nlevels > 1 else group_index.name, allow_enable=False, allow_disable=True, allow_modify=False, ), wrapper_kwargs, ) new_wrapper = ArrayWrapper(**wrapper_kwargs) alloc_ranges = AllocRanges( ArrayWrapper( index=wrapper.index, columns=new_wrapper.get_columns(), ndim=new_wrapper.get_ndim(), freq=wrapper.freq, column_only_select=False, range_only_select=True, ), np.concatenate(alloc_ranges), ) allocations = row_stack_arrays(allocations) return cls(new_wrapper, alloc_ranges, allocations) @classmethod def from_pypfopt( cls: 
@classmethod
def from_riskfolio(
    cls: tp.Type[PortfolioOptimizerT],
    returns: tp.AnyArray2d,
    wrapper: tp.Optional[ArrayWrapper] = None,
    **kwargs,
) -> PortfolioOptimizerT:
    """`PortfolioOptimizer.from_optimize_func` applied on `riskfolio_optimize`.

    If a wrapper is not provided, parses the wrapper from `returns`, which then must not be a template.
    Otherwise, the wrapper must be an instance of `ArrayWrapper`.

    Returns that are not a template are replaced by a `RepFunc` whose function takes
    `index_slice` and returns `returns.iloc[index_slice]`.

    Raises:
        TypeError: If no wrapper is provided and `returns` is a template.
    """
    returns_is_template = isinstance(returns, CustomTemplate)

    if wrapper is not None:
        checks.assert_instance_of(wrapper, ArrayWrapper, arg_name="wrapper")
    elif returns_is_template:
        # A template carries no index/columns to build a wrapper from
        raise TypeError("Must provide a wrapper if returns are a template")
    else:
        wrapper = ArrayWrapper.from_obj(returns)

    if not returns_is_template:
        returns = RepFunc(lambda index_slice, _returns=returns: _returns.iloc[index_slice])
    return cls.from_optimize_func(wrapper, riskfolio_optimize, returns, **kwargs)
def fill_allocations(
    self,
    dropna: tp.Optional[str] = None,
    fill_value: tp.Scalar = np.nan,
    wrap_kwargs: tp.KwargsLike = None,
    squeeze_groups: bool = True,
) -> tp.Frame:
    """Fill an empty DataFrame with allocations.

    Set `dropna` to 'all' to remove all NaN rows, or to 'head' to remove any rows coming
    before the first allocation. If there is no allocation to place, 'head' removes all rows.

    Args:
        dropna (str): Either None (keep all rows), 'all', or 'head'. Case-insensitive.
        fill_value (scalar): Value the DataFrame is filled with before allocations are placed.
        wrap_kwargs (dict): Keyword arguments passed to `self.wrapper.fill`.
        squeeze_groups (bool): Whether to drop the group levels from columns
            if the wrapper has only one group (`grouped_ndim` is 1).

    Returns:
        DataFrame with allocations placed at their allocation indices.

    Raises:
        ValueError: If `dropna` is neither None, 'all', nor 'head'.
    """
    if wrap_kwargs is None:
        wrap_kwargs = {}
    out = self.wrapper.fill(fill_value, group_by=False, **wrap_kwargs)
    idx_arr = self.alloc_records.get_field_arr("idx")
    group_arr = self.alloc_records.col_arr
    allocations = self._allocations
    if isinstance(self.alloc_records, AllocRanges):
        # Only closed ranges are placed
        status_arr = self.alloc_records.get_field_arr("status")
        closed_mask = status_arr == RangeStatus.Closed
        idx_arr = idx_arr[closed_mask]
        group_arr = group_arr[closed_mask]
        allocations = allocations[closed_mask]
    for g in range(len(self.alloc_records.wrapper.columns)):
        # Place the allocations of each group into the rows and columns of that group
        group_mask = group_arr == g
        index_mask = np.full(len(self.wrapper.index), False)
        index_mask[idx_arr[group_mask]] = True
        column_mask = self.wrapper.grouper.get_groups() == g
        out.loc[index_mask, column_mask] = allocations[group_mask]
    if dropna is not None:
        dropna_mode = dropna.lower()
        if dropna_mode == "all":
            out = out.dropna(how="all")
        elif dropna_mode == "head":
            if len(idx_arr) > 0:
                out = out.iloc[idx_arr.min() :]
            else:
                # Taking the minimum of an empty array raises an error:
                # without any allocation, every row counts as coming before the first one
                out = out.iloc[len(out) :]
        else:
            raise ValueError(f"Invalid dropna: '{dropna}'")
    if squeeze_groups and self.wrapper.grouped_ndim == 1:
        n_group_levels = self.wrapper.grouper.get_index().nlevels
        out = out.droplevel(tuple(range(n_group_levels)), axis=1)
    return out
def plot(
    self,
    column: tp.Optional[tp.Label] = None,
    dropna: tp.Optional[str] = "head",
    line_shape: str = "hv",
    plot_rb_dates: tp.Optional[bool] = None,
    trace_kwargs: tp.KwargsLikeSequence = None,
    add_shape_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot allocations.

    Allocations are filled with `PortfolioOptimizer.fill_allocations`, forward-filled,
    and drawn as an area plot. Nothing is drawn if there are no allocation records.

    Args:
        column (str): Name of the allocation group to plot.
        dropna (str): See `PortfolioOptimizer.fill_allocations`.
        line_shape (str): Line shape.
        plot_rb_dates (bool): Whether to plot rebalancing dates.

            Defaults to True if there are no more than 20 rebalancing dates.
        trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
        add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` for rebalancing dates.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Returns:
        The figure the traces and shapes were added to.

    Usage:
        * Continuing with the examples under `PortfolioOptimizer.from_optimize_func`:

        ```pycon
        >>> from vectorbtpro import *

        >>> pfo = vbt.PortfolioOptimizer.from_random(
        ...     vbt.ArrayWrapper(
        ...         index=pd.date_range("2020-01-01", "2021-01-01"),
        ...         columns=["MSFT", "AMZN", "AAPL"],
        ...         ndim=2
        ...     ),
        ...     every="MS",
        ...     seed=40
        ... )
        >>> pfo.plot().show()
        ```

        ![](/assets/images/api/pfopt_plot.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/pfopt_plot.dark.svg#only-dark){: .iimg loading=lazy }
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    from vectorbtpro.utils.figure import make_figure

    # Select a single allocation group
    self_group = self.select_col(column=column)
    if fig is None:
        fig = make_figure()
    fig.update_layout(**layout_kwargs)

    if self_group.alloc_records.count() > 0:
        # Forward-fill such that each allocation persists until the next one
        filled_allocations = self_group.fill_allocations(dropna=dropna).ffill()

        fig = filled_allocations.vbt.areaplot(
            line_shape=line_shape,
            trace_kwargs=trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

        if plot_rb_dates is None or (isinstance(plot_rb_dates, bool) and plot_rb_dates):
            rb_dates = self_group.allocations.index
            if plot_rb_dates is None:
                # Decide automatically based on the number of rebalancing dates
                plot_rb_dates = len(rb_dates) <= 20
            if plot_rb_dates:
                # One dotted line per rebalancing date, spanning the full height of the plot
                add_shape_kwargs = merge_dicts(
                    dict(
                        type="line",
                        line=dict(
                            color=fig.layout.template.layout.plot_bgcolor,
                            dash="dot",
                            width=1,
                        ),
                        xref="x",
                        yref="paper",
                        y0=0,
                        y1=1,
                    ),
                    add_shape_kwargs,
                )
                for rb_date in rb_dates:
                    fig.add_shape(x0=rb_date, x1=rb_date, **add_shape_kwargs)
    return fig
plots_defaults(self) -> tp.Kwargs: """Defaults for `PortfolioOptimizer.plots`. Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and `plots` from `vectorbtpro._settings.pfopt`.""" from vectorbtpro._settings import settings pfopt_plots_cfg = settings["pfopt"]["plots"] return merge_dicts(Analyzable.plots_defaults.__get__(self), pfopt_plots_cfg) _subplots: tp.ClassVar[Config] = HybridConfig( dict( alloc_ranges=dict( title="Allocation Ranges", plot_func="alloc_records.plot", check_alloc_ranges=True, tags="alloc_ranges", ), plot=dict( title="Allocations", plot_func="plot", tags="allocations", ), ) ) @property def subplots(self) -> Config: return self._subplots PortfolioOptimizer.override_metrics_doc(__pdoc__) PortfolioOptimizer.override_subplots_doc(__pdoc__) PFO = PortfolioOptimizer """Shortcut for `PortfolioOptimizer`.""" __pdoc__["PFO"] = False # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
@register_jitted(cache=True)
def get_alloc_points_nb(
    filled_allocations: tp.Array2d,
    valid_only: bool = True,
    nonzero_only: bool = True,
    unique_only: bool = True,
) -> tp.Array1d:
    """Get allocation points from filled allocations.

    Returns the positions of the rows that register a new allocation.

    If `valid_only` is True, does not register a new allocation when all points are NaN.
    If `nonzero_only` is True, does not register a new allocation when no point has
    an absolute value greater than zero (NaN does not count as non-zero).
    If `unique_only` is True, does not register a new allocation when it's the same as the last
    registered one. Points are compared with `!=`, so a NaN point always counts as a difference."""
    n_rows = filled_allocations.shape[0]
    n_cols = filled_allocations.shape[1]
    points = np.empty(n_rows, dtype=int_)
    n_points = 0
    for row in range(n_rows):
        has_valid = False
        has_nonzero = False
        has_changed = False
        for col in range(n_cols):
            value = filled_allocations[row, col]
            if not np.isnan(value):
                has_valid = True
            if abs(value) > 0:
                has_nonzero = True
            # Without a previously registered row there is nothing to be the same as
            if n_points == 0 or value != filled_allocations[points[n_points - 1], col]:
                has_changed = True
        skip_row = (
            (valid_only and not has_valid)
            or (nonzero_only and not has_nonzero)
            or (unique_only and not has_changed)
        )
        if not skip_row:
            points[n_points] = row
            n_points += 1
    return points[:n_points]
@register_chunkable(
    size=ch.ArraySizer(arg_query="index_points", axis=0),
    arg_take_spec=dict(
        n_cols=None,
        index_points=ch.ArraySlicer(axis=0),
        allocate_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def allocate_meta_nb(
    n_cols: int,
    index_points: tp.Array1d,
    allocate_func_nb: tp.Callable,
    *args,
) -> tp.Array2d:
    """Allocate by mapping each index point.

    `allocate_func_nb` must take the point index, the index point, and `*args`.
    Must return a 1-dim array with the same size as `n_cols`.

    Returns a two-dimensional array with one row per index point and `n_cols` columns."""
    n_points = index_points.shape[0]
    allocations = np.empty((n_points, n_cols), dtype=float_)
    for k in prange(n_points):
        allocations[k] = allocate_func_nb(k, index_points[k], *args)
    return allocations
@register_jitted(cache=True)
def prepare_alloc_points_nb(
    index_points: tp.Array1d,
    allocations: tp.Array2d,
    group: int,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
    """Prepare allocation points.

    Skips allocations that consist of NaN only. If an allocation has the same index point as
    the previously kept one, it overwrites that allocation instead of creating a new record.

    Returns the allocation point records, with `group` as column, and the allocations kept."""
    n_rows = allocations.shape[0]
    n_cols = allocations.shape[1]
    records = np.empty_like(index_points, dtype=alloc_point_dt)
    kept_allocations = np.empty_like(allocations)
    n_kept = 0
    for row in range(n_rows):
        has_valid = False
        for col in range(n_cols):
            if not np.isnan(allocations[row, col]):
                has_valid = True
                break
        if has_valid:
            point = index_points[row]
            if n_kept > 0 and records["alloc_idx"][n_kept - 1] == point:
                # Same index point as the last record: the later allocation wins
                kept_allocations[n_kept - 1] = allocations[row]
            else:
                records["id"][n_kept] = n_kept
                records["col"][n_kept] = group
                records["alloc_idx"][n_kept] = point
                kept_allocations[n_kept] = allocations[row]
                n_kept += 1
    return records[:n_kept], kept_allocations[:n_kept]
@register_jitted(cache=True)
def rescale_allocations_nb(allocations: tp.Array2d, to_range: tp.Tuple[float, float]) -> tp.Array2d:
    """Rescale allocations to a new scale.

    Positive and negative weights are rescaled separately from each other: each positive weight
    is divided by the sum of positive weights of its row and rescaled with `rescale_nb` from
    `(0, 1)` to `(max(0, new_min), new_max)`, while the absolute value of each negative weight
    is divided by the sum of absolute negative weights and rescaled from `(0, 1)` to
    `(min(new_max, 0), new_min)`.

    A row consisting of NaN only becomes NaN. A row without any positive or negative weight
    becomes zero. In any other row, NaN stays NaN and zero stays zero.

    Raises an error if a bound of the new scale is not finite, if the minimum is not lower
    than the maximum, or if a row has weights of a sign that the new scale does not cover."""
    new_min, new_max = to_range
    if np.isnan(new_min) or np.isinf(new_min):
        raise ValueError("Minimum of the new scale must be finite")
    if np.isnan(new_max) or np.isinf(new_max):
        raise ValueError("Maximum of the new scale must be finite")
    if new_min >= new_max:
        raise ValueError("Minimum cannot be equal to or higher than maximum")

    unit_range = (0.0, 1.0)
    pos_range = (max(0.0, new_min), new_max)
    neg_range = (min(new_max, 0.0), new_min)

    n_rows = allocations.shape[0]
    n_cols = allocations.shape[1]
    out = np.empty_like(allocations, dtype=float_)
    for row in range(n_rows):
        has_valid = False
        has_nonzero = False
        pos_sum = 0.0
        neg_sum = 0.0
        for col in range(n_cols):
            weight = allocations[row, col]
            if np.isnan(weight):
                continue
            has_valid = True
            if weight > 0:
                has_nonzero = True
                pos_sum += weight
            elif weight < 0:
                has_nonzero = True
                neg_sum += abs(weight)

        if not has_valid:
            out[row] = np.nan
            continue
        if not has_nonzero:
            out[row] = 0.0
            continue
        if new_max <= 0 and pos_sum > 0:
            raise ValueError("Cannot rescale positive weights to a negative scale")
        if new_min >= 0 and neg_sum > 0:
            raise ValueError("Cannot rescale negative weights to a positive scale")

        for col in range(n_cols):
            weight = allocations[row, col]
            if np.isnan(weight):
                out[row, col] = np.nan
            elif weight > 0:
                out[row, col] = rescale_nb(weight / pos_sum, unit_range, pos_range)
            elif weight < 0:
                out[row, col] = rescale_nb(abs(weight) / neg_sum, unit_range, neg_range)
            else:
                out[row, col] = 0.0
    return out
@override_field_config(alloc_ranges_field_config)
class AllocRanges(Ranges):
    """Extends `vectorbtpro.generic.ranges.Ranges` for working with allocation range records."""

    @property
    def field_config(self) -> Config:
        return self._field_config
# =================================================================================== """Modules for working with portfolio.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.portfolio.nb import * from vectorbtpro.portfolio.pfopt import * from vectorbtpro.portfolio.base import * from vectorbtpro.portfolio.call_seq import * from vectorbtpro.portfolio.chunking import * from vectorbtpro.portfolio.decorators import * from vectorbtpro.portfolio.logs import * from vectorbtpro.portfolio.orders import * from vectorbtpro.portfolio.preparing import * from vectorbtpro.portfolio.trades import * __exclude_from__all__ = [ "enums", ] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. 
# =================================================================================== """Base class for simulating a portfolio and measuring its performance.""" import inspect import string from functools import partial import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base.indexes import ExceptLevel from vectorbtpro.base.merging import row_stack_arrays from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.reshaping import ( to_1d_array, to_2d_array, broadcast_array_to, to_pd_array, to_2d_shape, ) from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping from vectorbtpro.data.base import OHLCDataMixin from vectorbtpro.generic import nb as generic_nb from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.generic.drawdowns import Drawdowns from vectorbtpro.generic.sim_range import SimRangeMixin from vectorbtpro.portfolio import nb, enums from vectorbtpro.portfolio.decorators import attach_shortcut_properties, attach_returns_acc_methods from vectorbtpro.portfolio.logs import Logs from vectorbtpro.portfolio.orders import Orders, FSOrders from vectorbtpro.portfolio.pfopt.base import PortfolioOptimizer from vectorbtpro.portfolio.preparing import ( PFPrepResult, BasePFPreparer, FOPreparer, FSPreparer, FOFPreparer, FDOFPreparer, ) from vectorbtpro.portfolio.trades import Trades, EntryTrades, ExitTrades, Positions from vectorbtpro.records.base import Records from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.returns.accessors import ReturnsAccessor from vectorbtpro.utils import checks from vectorbtpro.utils.attr_ import get_dict_attr from vectorbtpro.utils.base import Base from vectorbtpro.utils.colors import adjust_opacity from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig, atomic_dict from vectorbtpro.utils.decorators import 
def records_indexing_func(
    self: "Portfolio",
    obj: tp.RecordArray,
    wrapper_meta: dict,
    cls: tp.Union[type, str],
    groups_only: bool = False,
    **kwargs,
) -> tp.RecordArray:
    """Apply indexing function on records.

    Wraps `obj` with the records class `cls` (looked up as an attribute of the portfolio
    if provided as a string) and the wrapper returned by `fix_wrapper_for_records`,
    performs indexing using `wrapper_meta`, and returns the values of the indexed records.

    If `groups_only` is True, resolves the wrapper and uses group indices from
    `wrapper_meta` as column indices."""
    records_wrapper = fix_wrapper_for_records(self)
    if groups_only:
        records_wrapper = records_wrapper.resolve()
        # Work on a copy to leave the metadata of the caller untouched
        wrapper_meta = {**wrapper_meta, "col_idxs": wrapper_meta["group_idxs"]}
    records_cls = getattr(self, cls) if isinstance(cls, str) else cls
    records = records_cls(records_wrapper, obj)
    indexed_records = records.indexing_func(
        records_meta=records.indexing_func_meta(wrapper_meta=wrapper_meta),
    )
    return indexed_records.values
fill_with_zero=fill_with_zero, ) .obj.values ) returns_acc_config = ReadonlyConfig( { "daily_returns": dict(source_name="daily"), "annual_returns": dict(source_name="annual"), "cumulative_returns": dict(source_name="cumulative"), "annualized_return": dict(source_name="annualized"), "annualized_volatility": dict(), "calmar_ratio": dict(), "omega_ratio": dict(), "sharpe_ratio": dict(), "sharpe_ratio_std": dict(), "prob_sharpe_ratio": dict(), "deflated_sharpe_ratio": dict(), "downside_risk": dict(), "sortino_ratio": dict(), "information_ratio": dict(), "beta": dict(), "alpha": dict(), "tail_ratio": dict(), "value_at_risk": dict(), "cond_value_at_risk": dict(), "capture_ratio": dict(), "up_capture_ratio": dict(), "down_capture_ratio": dict(), "drawdown": dict(), "max_drawdown": dict(), } ) """_""" __pdoc__[ "returns_acc_config" ] = f"""Config of returns accessor methods to be attached to `Portfolio`. ```python {returns_acc_config.prettify()} ``` """ shortcut_config = ReadonlyConfig( { "filled_close": dict(group_by_aware=False, decorator=cached_property), "filled_bm_close": dict(group_by_aware=False, decorator=cached_property), "weights": dict(group_by_aware=False, decorator=cached_property, obj_type="red_array"), "long_view": dict(obj_type="portfolio"), "short_view": dict(obj_type="portfolio"), "orders": dict( obj_type="records", field_aliases=("order_records",), wrap_func=lambda pf, obj, **kwargs: pf.orders_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="orders_cls"), resample_func=partial(records_resample_func, cls="orders_cls"), ), "logs": dict( obj_type="records", field_aliases=("log_records",), wrap_func=lambda pf, obj, **kwargs: pf.logs_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="logs_cls"), 
resample_func=partial(records_resample_func, cls="logs_cls"), ), "entry_trades": dict( obj_type="records", field_aliases=("entry_trade_records",), wrap_func=lambda pf, obj, **kwargs: pf.entry_trades_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="entry_trades_cls"), resample_func=partial(records_resample_func, cls="entry_trades_cls"), ), "exit_trades": dict( obj_type="records", field_aliases=("exit_trade_records",), wrap_func=lambda pf, obj, **kwargs: pf.exit_trades_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="exit_trades_cls"), resample_func=partial(records_resample_func, cls="exit_trades_cls"), ), "trades": dict( obj_type="records", field_aliases=("trade_records",), wrap_func=lambda pf, obj, **kwargs: pf.trades_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="trades_cls"), resample_func=partial(records_resample_func, cls="trades_cls"), ), "trade_history": dict(), "signals": dict(), "positions": dict( obj_type="records", field_aliases=("position_records",), wrap_func=lambda pf, obj, **kwargs: pf.positions_cls.from_records( fix_wrapper_for_records(pf), obj, open=pf.open_flex, high=pf.high_flex, low=pf.low_flex, close=pf.close_flex, ), indexing_func=partial(records_indexing_func, cls="positions_cls"), resample_func=partial(records_resample_func, cls="positions_cls"), ), "drawdowns": dict( obj_type="records", field_aliases=("drawdown_records",), wrap_func=lambda pf, obj, **kwargs: pf.drawdowns_cls.from_records( fix_wrapper_for_records(pf).resolve(), obj, ), indexing_func=partial(records_indexing_func, cls="drawdowns_cls", groups_only=True), 
resample_func=partial(records_resample_func, cls="drawdowns_cls"), ), "init_position": dict(obj_type="red_array", group_by_aware=False), "asset_flow": dict( group_by_aware=False, resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)), ), "long_asset_flow": dict( method_name="get_asset_flow", group_by_aware=False, method_kwargs=dict(direction="longonly"), resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)), ), "short_asset_flow": dict( method_name="get_asset_flow", group_by_aware=False, method_kwargs=dict(direction="shortonly"), resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)), ), "assets": dict(group_by_aware=False), "long_assets": dict( method_name="get_assets", group_by_aware=False, method_kwargs=dict(direction="longonly"), ), "short_assets": dict( method_name="get_assets", group_by_aware=False, method_kwargs=dict(direction="shortonly"), ), "position_mask": dict(), "long_position_mask": dict(method_name="get_position_mask", method_kwargs=dict(direction="longonly")), "short_position_mask": dict(method_name="get_position_mask", method_kwargs=dict(direction="shortonly")), "position_coverage": dict(obj_type="red_array"), "long_position_coverage": dict( method_name="get_position_coverage", obj_type="red_array", method_kwargs=dict(direction="longonly"), ), "short_position_coverage": dict( method_name="get_position_coverage", obj_type="red_array", method_kwargs=dict(direction="shortonly"), ), "position_entry_price": dict(group_by_aware=False), "position_exit_price": dict(group_by_aware=False), "init_cash": dict(obj_type="red_array"), "cash_deposits": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))), "total_cash_deposits": dict(obj_type="red_array"), "cash_earnings": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))), "total_cash_earnings": dict(obj_type="red_array"), "cash_flow": dict(resample_func="sum", 
resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))), "free_cash_flow": dict( method_name="get_cash_flow", method_kwargs=dict(free=True), resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)), ), "cash": dict(), "position": dict(method_name="get_assets", group_by_aware=False), "debt": dict(method_name=None, group_by_aware=False), "locked_cash": dict(method_name=None, group_by_aware=False), "free_cash": dict(method_name="get_cash", method_kwargs=dict(free=True)), "init_price": dict(obj_type="red_array", group_by_aware=False), "init_position_value": dict(obj_type="red_array"), "init_value": dict(obj_type="red_array"), "input_value": dict(obj_type="red_array"), "asset_value": dict(), "long_asset_value": dict(method_name="get_asset_value", method_kwargs=dict(direction="longonly")), "short_asset_value": dict(method_name="get_asset_value", method_kwargs=dict(direction="shortonly")), "gross_exposure": dict(), "long_gross_exposure": dict(method_name="get_gross_exposure", method_kwargs=dict(direction="longonly")), "short_gross_exposure": dict(method_name="get_gross_exposure", method_kwargs=dict(direction="shortonly")), "net_exposure": dict(), "value": dict(), "allocations": dict(group_by_aware=False), "long_allocations": dict( method_name="get_allocations", method_kwargs=dict(direction="longonly"), group_by_aware=False, ), "short_allocations": dict( method_name="get_allocations", method_kwargs=dict(direction="shortonly"), group_by_aware=False, ), "total_profit": dict(obj_type="red_array"), "final_value": dict(obj_type="red_array"), "total_return": dict(obj_type="red_array"), "returns": dict(resample_func=returns_resample_func), "log_returns": dict( method_name="get_returns", method_kwargs=dict(log_returns=True), resample_func=partial(returns_resample_func, log_returns=True), ), "daily_log_returns": dict( method_name="get_returns", method_kwargs=dict(daily_returns=True, log_returns=True), resample_func=partial(returns_resample_func, log_returns=True), ), 
"asset_pnl": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))), "asset_returns": dict(resample_func=returns_resample_func), "market_value": dict(), "market_returns": dict(resample_func=returns_resample_func), "bm_value": dict(), "bm_returns": dict(resample_func=returns_resample_func), "total_market_return": dict(obj_type="red_array"), "daily_returns": dict(resample_func=returns_resample_func), "annual_returns": dict(resample_func=returns_resample_func), "cumulative_returns": dict(), "annualized_return": dict(obj_type="red_array"), "annualized_volatility": dict(obj_type="red_array"), "calmar_ratio": dict(obj_type="red_array"), "omega_ratio": dict(obj_type="red_array"), "sharpe_ratio": dict(obj_type="red_array"), "sharpe_ratio_std": dict(obj_type="red_array"), "prob_sharpe_ratio": dict(obj_type="red_array"), "deflated_sharpe_ratio": dict(obj_type="red_array"), "downside_risk": dict(obj_type="red_array"), "sortino_ratio": dict(obj_type="red_array"), "information_ratio": dict(obj_type="red_array"), "beta": dict(obj_type="red_array"), "alpha": dict(obj_type="red_array"), "tail_ratio": dict(obj_type="red_array"), "value_at_risk": dict(obj_type="red_array"), "cond_value_at_risk": dict(obj_type="red_array"), "capture_ratio": dict(obj_type="red_array"), "up_capture_ratio": dict(obj_type="red_array"), "down_capture_ratio": dict(obj_type="red_array"), "drawdown": dict(), "max_drawdown": dict(obj_type="red_array"), } ) """_""" __pdoc__[ "shortcut_config" ] = f"""Config of shortcut properties to be attached to `Portfolio`. 
class MetaPortfolio(type(Analyzable)):
    """Metaclass for `Portfolio`.

    Exposes the in-output config as a read-only property on the class itself."""

    @property
    def in_output_config(cls) -> Config:
        """In-output config (the value stored under `_in_output_config` of the class)."""
        config = cls._in_output_config
        return config
@classmethod
def row_stack_objs(
    cls: tp.Type[PortfolioT],
    objs: tp.Sequence[tp.Any],
    wrappers: tp.Sequence[ArrayWrapper],
    grouping: str = "columns_or_groups",
    obj_name: tp.Optional[str] = None,
    obj_type: tp.Optional[str] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    cash_sharing: bool = False,
    row_stack_func: tp.Optional[tp.Callable] = None,
    **kwargs,
) -> tp.Any:
    """Stack (two-dimensional) objects along rows.

    `row_stack_func` must take the portfolio class, and all the arguments passed to this method.
    If you don't need any of the arguments, make `row_stack_func` accept them as `**kwargs`.

    If all the objects are None, boolean, or empty, returns the first one.

    Args:
        objs: Objects to stack, one per portfolio. A single-element sequence is unpacked.
        wrappers: Wrapper of each object in `objs` (same order), used to wrap it before stacking.
        grouping: One of "columns_or_groups", "columns", "groups", and "cash_sharing".
        obj_name: Name of the in-output, used in error messages.
        obj_type: If "array", objects are always stacked. If None, stacking is attempted
            only if the first object is a NumPy array whose number of columns matches `wrapper`.
        wrapper: Wrapper of the stacked portfolio.
        cash_sharing: Whether cash sharing is enabled; only used with `grouping="cash_sharing"`.
        row_stack_func: Custom stacking function that replaces the default logic.

    Raises:
        ValueError: If trivial objects (None, boolean, empty) differ from the first one,
            if `grouping` is not supported, or if the objects cannot be stacked."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)

    # Trivial objects (None, booleans, empty arrays) cannot be stacked: they must all be equal
    all_none = True
    for obj in objs:
        if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
            if not checks.is_deep_equal(obj, objs[0]):
                raise ValueError(f"Cannot unify scalar in-outputs with the name '{obj_name}'")
        else:
            all_none = False
            break
    if all_none:
        return objs[0]

    if row_stack_func is not None:
        return row_stack_func(
            cls,
            objs,
            wrappers,
            grouping=grouping,
            obj_name=obj_name,
            obj_type=obj_type,
            wrapper=wrapper,
            **kwargs,
        )

    # Translate the grouping mode into a `group_by` instruction for the wrappers
    if grouping == "columns_or_groups":
        obj_group_by = None
    elif grouping == "columns":
        obj_group_by = False
    elif grouping == "groups":
        obj_group_by = None
    elif grouping == "cash_sharing":
        obj_group_by = None if cash_sharing else False
    else:
        raise ValueError(f"Grouping '{grouping}' is not supported")

    if obj_type is None and checks.is_np_array(objs[0]):
        # Without an explicit type, guess from the layout of the first object
        n_cols = wrapper.get_shape_2d(group_by=obj_group_by)[1]
        can_stack = (objs[0].ndim == 1 and n_cols == 1) or (objs[0].ndim == 2 and objs[0].shape[1] == n_cols)
    elif obj_type is not None and obj_type == "array":
        can_stack = True
    else:
        can_stack = False

    if can_stack:
        wrapped_objs = []
        for i, obj in enumerate(objs):
            wrapped_objs.append(wrappers[i].wrap(obj, group_by=obj_group_by))
        return wrapper.row_stack_arrs(*wrapped_objs, group_by=obj_group_by, wrap=False)
    raise ValueError(f"Cannot figure out how to stack in-outputs with the name '{obj_name}' along rows")


@classmethod
def row_stack_in_outputs(
    cls: tp.Type[PortfolioT],
    *objs: tp.MaybeTuple[PortfolioT],
    **kwargs,
) -> tp.Optional[tp.NamedTuple]:
    """Stack `Portfolio.in_outputs` along rows.

    All in-output tuples must be either None or have the same fields.

    If the field can be found in the attributes of this `Portfolio` instance, reads the
    attribute's options to get requirements for the type and layout of the in-output object.

    For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
    `Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
    Performs stacking on the in-output objects of the same field using `Portfolio.row_stack_objs`.

    Returns:
        None if no object has in-outputs, otherwise a named tuple of the same type as
        the in-outputs of the first object.

    Raises:
        ValueError: If only some objects have in-outputs, or if their fields differ."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)

    if all(obj.in_outputs is None for obj in objs):
        return None
    # Reject a mix of missing and present in-outputs before touching `_asdict()`:
    # otherwise None would fail with AttributeError instead of the intended ValueError
    if any(obj.in_outputs is None for obj in objs):
        raise ValueError("Objects to be merged must have the same in-output fields")

    field_sets = [set(obj.in_outputs._asdict().keys()) for obj in objs]
    all_keys = set().union(*field_sets)
    for field_set in field_sets:
        if len(all_keys.difference(field_set)) > 0:
            raise ValueError("Objects to be merged must have the same in-output fields")

    cls_dir = set(dir(cls))
    new_in_outputs = {}
    for field in objs[0].in_outputs._asdict().keys():
        field_options = merge_dicts(
            cls.parse_field_options(field),
            cls.in_output_config.get(field, None),
        )
        # Resolve the target name once so that the lookup below cannot fail on a missing "field" key
        field_name = field_options.get("field", field)
        if field_name in cls_dir:
            prop = getattr(cls, field_name)
            prop_options = getattr(prop, "options", {})
            obj_type = prop_options.get("obj_type", "array")
            group_by_aware = prop_options.get("group_by_aware", True)
            row_stack_func = prop_options.get("row_stack_func", None)
        else:
            obj_type = None
            group_by_aware = True
            row_stack_func = None
        # Options of the field take precedence over options of the attribute,
        # and keyword arguments of the caller take precedence over both
        _kwargs = merge_dicts(
            dict(
                grouping=field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
                obj_name=field_name,
                obj_type=field_options.get("obj_type", obj_type),
                row_stack_func=field_options.get("row_stack_func", row_stack_func),
            ),
            kwargs,
        )
        new_field_obj = cls.row_stack_objs(
            [getattr(obj.in_outputs, field) for obj in objs],
            [obj.wrapper for obj in objs],
            **_kwargs,
        )
        new_in_outputs[field] = new_field_obj

    return type(objs[0].in_outputs)(**new_in_outputs)
row_stack( cls_or_self: tp.MaybeType[PortfolioT], *objs: tp.MaybeTuple[PortfolioT], wrapper_kwargs: tp.KwargsLike = None, group_by: tp.GroupByLike = None, combine_init_cash: bool = False, combine_init_position: bool = False, combine_init_price: bool = False, **kwargs, ) -> PortfolioT: """Stack multiple `Portfolio` instances along rows. Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers. Cash sharing must be the same among all objects. Close, benchmark close, cash deposits, cash earnings, call sequence, and other two-dimensional arrays are stacked using `vectorbtpro.base.wrapping.ArrayWrapper.row_stack_arrs`. In-outputs are stacked using `Portfolio.row_stack_in_outputs`. Records are stacked using `vectorbtpro.records.base.Records.row_stack_records_arrs`. If the initial cash of each object is one of the options in `vectorbtpro.portfolio.enums.InitCashMode`, it will be retained for the resulting object. Once any of the objects has the initial cash listed as an absolute amount or an array, the initial cash of the first object will be copied over to the final object, while the initial cash of all other objects will be resolved and used as cash deposits, unless they all are zero. Set `combine_init_cash` to True to simply sum all initial cash arrays. If only the first object has an initial position greater than zero, it will be copied over to the final object. Otherwise, an error will be thrown, unless `combine_init_position` is enabled to sum all initial position arrays. The same goes for the initial price, which becomes a candidate for stacking only if any of the arrays are not NaN. !!! 
note When possible, avoid using initial position and price in objects to be stacked: there is currently no way of injecting them in the correct order, while simply taking the sum or weighted average may distort the reality since they weren't available prior to the actual simulation.""" if not isinstance(cls_or_self, type): objs = (cls_or_self, *objs) cls = type(cls_or_self) else: cls = cls_or_self if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, Portfolio): raise TypeError("Each object to be merged must be an instance of Portfolio") _objs = list(map(lambda x: x.disable_weights(), objs)) if "wrapper" not in kwargs: wrapper_kwargs = merge_dicts(dict(group_by=group_by), wrapper_kwargs) kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in _objs], **wrapper_kwargs) for i in range(1, len(_objs)): if _objs[i].cash_sharing != _objs[0].cash_sharing: raise ValueError("Objects to be merged must have the same 'cash_sharing'") kwargs["cash_sharing"] = _objs[0].cash_sharing cs_group_by = None if kwargs["cash_sharing"] else False cs_n_cols = kwargs["wrapper"].get_shape_2d(group_by=cs_group_by)[1] n_cols = kwargs["wrapper"].shape_2d[1] if "close" not in kwargs: kwargs["close"] = kwargs["wrapper"].row_stack_arrs( *[obj.close for obj in _objs], group_by=False, wrap=False, ) if "open" not in kwargs: stack_open_objs = True for obj in _objs: if obj._open is None: stack_open_objs = False break if stack_open_objs: kwargs["open"] = kwargs["wrapper"].row_stack_arrs( *[obj.open for obj in _objs], group_by=False, wrap=False, ) if "high" not in kwargs: stack_high_objs = True for obj in _objs: if obj._high is None: stack_high_objs = False break if stack_high_objs: kwargs["high"] = kwargs["wrapper"].row_stack_arrs( *[obj.high for obj in _objs], group_by=False, wrap=False, ) if "low" not in kwargs: stack_low_objs = True for obj in _objs: if obj._low is None: stack_low_objs = False break if stack_low_objs: kwargs["low"] = 
kwargs["wrapper"].row_stack_arrs( *[obj.low for obj in _objs], group_by=False, wrap=False, ) if "order_records" not in kwargs: kwargs["order_records"] = Orders.row_stack_records_arrs(*[obj.orders for obj in _objs], **kwargs) if "log_records" not in kwargs: kwargs["log_records"] = Logs.row_stack_records_arrs(*[obj.logs for obj in _objs], **kwargs) if "init_cash" not in kwargs: stack_init_cash_objs = False for obj in _objs: if not checks.is_int(obj._init_cash) or obj._init_cash not in enums.InitCashMode: stack_init_cash_objs = True break if stack_init_cash_objs: stack_init_cash_objs = False init_cash_objs = [] for i, obj in enumerate(_objs): init_cash_obj = obj.get_init_cash(group_by=cs_group_by) init_cash_obj = to_1d_array(init_cash_obj) init_cash_obj = broadcast_array_to(init_cash_obj, cs_n_cols) if i > 0 and (init_cash_obj != 0).any(): stack_init_cash_objs = True init_cash_objs.append(init_cash_obj) if stack_init_cash_objs: if not combine_init_cash: cash_deposits_objs = [] for i, obj in enumerate(_objs): cash_deposits_obj = obj.get_cash_deposits(group_by=cs_group_by) cash_deposits_obj = to_2d_array(cash_deposits_obj) cash_deposits_obj = broadcast_array_to( cash_deposits_obj, (cash_deposits_obj.shape[0], cs_n_cols), ) cash_deposits_obj = cash_deposits_obj.copy() if i > 0: cash_deposits_obj[0] = init_cash_objs[i] cash_deposits_objs.append(cash_deposits_obj) kwargs["cash_deposits"] = row_stack_arrays(cash_deposits_objs) kwargs["init_cash"] = init_cash_objs[0] else: kwargs["init_cash"] = np.asarray(init_cash_objs).sum(axis=0) else: kwargs["init_cash"] = init_cash_objs[0] if "init_position" not in kwargs: stack_init_position_objs = False init_position_objs = [] for i, obj in enumerate(_objs): init_position_obj = obj.get_init_position() init_position_obj = to_1d_array(init_position_obj) init_position_obj = broadcast_array_to(init_position_obj, n_cols) if i > 0 and (init_position_obj != 0).any(): stack_init_position_objs = True 
init_position_objs.append(init_position_obj) if stack_init_position_objs: if not combine_init_position: raise ValueError("Initial position cannot be stacked along rows") kwargs["init_position"] = np.asarray(init_position_objs).sum(axis=0) else: kwargs["init_position"] = init_position_objs[0] if "init_price" not in kwargs: stack_init_price_objs = False init_position_objs = [] init_price_objs = [] for i, obj in enumerate(_objs): init_position_obj = obj.get_init_position() init_position_obj = to_1d_array(init_position_obj) init_position_obj = broadcast_array_to(init_position_obj, n_cols) init_price_obj = obj.get_init_price() init_price_obj = to_1d_array(init_price_obj) init_price_obj = broadcast_array_to(init_price_obj, n_cols) if i > 0 and (init_position_obj != 0).any() and not np.isnan(init_price_obj).all(): stack_init_price_objs = True init_position_objs.append(init_position_obj) init_price_objs.append(init_price_obj) if stack_init_price_objs: if not combine_init_price: raise ValueError("Initial price cannot be stacked along rows") init_position_objs = np.asarray(init_position_objs) init_price_objs = np.asarray(init_price_objs) mask1 = (init_position_objs != 0).any(axis=1) mask2 = (~np.isnan(init_price_objs)).any(axis=1) mask = mask1 & mask2 init_position_objs = init_position_objs[mask] init_price_objs = init_price_objs[mask] nom = (init_position_objs * init_price_objs).sum(axis=0) denum = init_position_objs.sum(axis=0) kwargs["init_price"] = nom / denum else: kwargs["init_price"] = init_price_objs[0] if "cash_deposits" not in kwargs: stack_cash_deposits_objs = False for obj in _objs: if obj._cash_deposits.size > 1 or obj._cash_deposits.item() != 0: stack_cash_deposits_objs = True break if stack_cash_deposits_objs: kwargs["cash_deposits"] = kwargs["wrapper"].row_stack_arrs( *[obj.get_cash_deposits(group_by=cs_group_by) for obj in _objs], group_by=cs_group_by, wrap=False, ) else: kwargs["cash_deposits"] = np.array([[0.0]]) if "cash_earnings" not in kwargs: 
@classmethod
def column_stack_objs(
    cls: tp.Type[PortfolioT],
    objs: tp.Sequence[tp.Any],
    wrappers: tp.Sequence[ArrayWrapper],
    grouping: str = "columns_or_groups",
    obj_name: tp.Optional[str] = None,
    obj_type: tp.Optional[str] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    cash_sharing: bool = False,
    column_stack_func: tp.Optional[tp.Callable] = None,
    **kwargs,
) -> tp.Any:
    """Stack (one and two-dimensional) objects along column.

    `column_stack_func` must take the portfolio class, and all the arguments passed to this method.
    If you don't need any of the arguments, make `column_stack_func` accept them as `**kwargs`.

    If all the objects are None, boolean, or empty, returns the first one.

    Args:
        objs: Objects to stack, one per portfolio. A single-element sequence is unpacked.
        wrappers: Wrapper of each object in `objs` (same order), used to wrap it before stacking.
        grouping: One of "columns_or_groups", "columns", "groups", and "cash_sharing".
        obj_name: Name of the in-output, used in error messages.
        obj_type: "array" for two-dimensional and "red_array" for reduced (one value per column
            or group) objects. If None, the layout is guessed from the shape of the first object.
        wrapper: Wrapper of the stacked portfolio.
        cash_sharing: Whether cash sharing is enabled; only used with `grouping="cash_sharing"`.
        column_stack_func: Custom stacking function that replaces the default logic.

    Raises:
        ValueError: If trivial objects (None, boolean, empty) differ from the first one,
            if `grouping` is not supported, or if the objects cannot be stacked."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)

    # Trivial objects (None, booleans, empty arrays) cannot be stacked: they must all be equal
    all_none = True
    for obj in objs:
        if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
            if not checks.is_deep_equal(obj, objs[0]):
                raise ValueError(f"Cannot unify scalar in-outputs with the name '{obj_name}'")
        else:
            all_none = False
            break
    if all_none:
        return objs[0]

    if column_stack_func is not None:
        return column_stack_func(
            cls,
            objs,
            wrappers,
            grouping=grouping,
            obj_name=obj_name,
            obj_type=obj_type,
            wrapper=wrapper,
            **kwargs,
        )

    # Translate the grouping mode into a `group_by` instruction for the wrappers
    if grouping == "columns_or_groups":
        obj_group_by = None
    elif grouping == "columns":
        obj_group_by = False
    elif grouping == "groups":
        obj_group_by = None
    elif grouping == "cash_sharing":
        obj_group_by = None if cash_sharing else False
    else:
        raise ValueError(f"Grouping '{grouping}' is not supported")

    reduced = False
    # The type check must target the same object whose shape is inspected below (the first one),
    # not the loop variable left over from the scan above: with objects such as [None, array],
    # the latter would pass the check and then fail on `None.shape`
    if obj_type is None and checks.is_np_array(objs[0]):
        if to_2d_shape(objs[0].shape) == wrappers[0].get_shape_2d(group_by=obj_group_by):
            can_stack = True
        elif objs[0].shape == (wrappers[0].get_shape_2d(group_by=obj_group_by)[1],):
            can_stack = True
            reduced = True
        else:
            can_stack = False
    elif obj_type is not None and obj_type == "array":
        can_stack = True
    elif obj_type is not None and obj_type == "red_array":
        can_stack = True
        reduced = True
    else:
        can_stack = False

    if can_stack:
        if reduced:
            # One value per column (or group): concatenate instead of stacking
            wrapped_objs = []
            for i, obj in enumerate(objs):
                wrapped_objs.append(wrappers[i].wrap_reduced(obj, group_by=obj_group_by))
            return wrapper.concat_arrs(*wrapped_objs, group_by=obj_group_by).values
        wrapped_objs = []
        for i, obj in enumerate(objs):
            wrapped_objs.append(wrappers[i].wrap(obj, group_by=obj_group_by))
        return wrapper.column_stack_arrs(*wrapped_objs, group_by=obj_group_by, wrap=False)
    raise ValueError(f"Cannot figure out how to stack in-outputs with the name '{obj_name}' along columns")
@classmethod
def column_stack_in_outputs(
    cls: tp.Type[PortfolioT],
    *objs: tp.MaybeTuple[PortfolioT],
    **kwargs,
) -> tp.Optional[tp.NamedTuple]:
    """Stack `Portfolio.in_outputs` along columns.

    All in-output tuples must be either None or have the same fields.

    If the field can be found in the attributes of this `Portfolio` instance, reads the
    attribute's options to get requirements for the type and layout of the in-output object.

    For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
    `Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
    Performs stacking on the in-output objects of the same field using `Portfolio.column_stack_objs`.

    Returns:
        None if no object has in-outputs, otherwise a named tuple of the same type as
        the in-outputs of the first object.

    Raises:
        ValueError: If only some objects have in-outputs, or if their fields differ."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)

    if all(obj.in_outputs is None for obj in objs):
        return None
    # Reject a mix of missing and present in-outputs before touching `_asdict()`:
    # otherwise None would fail with AttributeError instead of the intended ValueError
    if any(obj.in_outputs is None for obj in objs):
        raise ValueError("Objects to be merged must have the same in-output fields")

    field_sets = [set(obj.in_outputs._asdict().keys()) for obj in objs]
    all_keys = set().union(*field_sets)
    for field_set in field_sets:
        if len(all_keys.difference(field_set)) > 0:
            raise ValueError("Objects to be merged must have the same in-output fields")

    cls_dir = set(dir(cls))
    new_in_outputs = {}
    for field in objs[0].in_outputs._asdict().keys():
        field_options = merge_dicts(
            cls.parse_field_options(field),
            cls.in_output_config.get(field, None),
        )
        # Resolve the target name once so that the lookup below cannot fail on a missing "field" key
        field_name = field_options.get("field", field)
        if field_name in cls_dir:
            prop = getattr(cls, field_name)
            prop_options = getattr(prop, "options", {})
            obj_type = prop_options.get("obj_type", "array")
            group_by_aware = prop_options.get("group_by_aware", True)
            column_stack_func = prop_options.get("column_stack_func", None)
        else:
            obj_type = None
            group_by_aware = True
            column_stack_func = None
        # Options of the field take precedence over options of the attribute,
        # and keyword arguments of the caller take precedence over both
        _kwargs = merge_dicts(
            dict(
                grouping=field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
                obj_name=field_name,
                obj_type=field_options.get("obj_type", obj_type),
                column_stack_func=field_options.get("column_stack_func", column_stack_func),
            ),
            kwargs,
        )
        new_field_obj = cls.column_stack_objs(
            [getattr(obj.in_outputs, field) for obj in objs],
            [obj.wrapper for obj in objs],
            **_kwargs,
        )
        new_in_outputs[field] = new_field_obj

    return type(objs[0].in_outputs)(**new_in_outputs)


@hybrid_method
def column_stack(
    cls_or_self: tp.MaybeType[PortfolioT],
    *objs: tp.MaybeTuple[PortfolioT],
    wrapper_kwargs: tp.KwargsLike = None,
    group_by: tp.GroupByLike = None,
    ffill_close: bool = False,
    fbfill_close: bool = False,
    **kwargs,
) -> PortfolioT:
    """Stack multiple `Portfolio` instances along columns.

    Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.

    Cash sharing must be the same among all objects.

    Two-dimensional arrays are stacked using
    `vectorbtpro.base.wrapping.ArrayWrapper.column_stack_arrs`
    while one-dimensional arrays are stacked using
    `vectorbtpro.base.wrapping.ArrayWrapper.concat_arrs`.

    In-outputs are stacked using `Portfolio.column_stack_in_outputs`. Records are stacked using
    `vectorbtpro.records.base.Records.column_stack_records_arrs`.

    Args:
        *objs: Portfolios to stack. When called on an instance, the instance comes first.
        wrapper_kwargs: Keyword arguments passed to `ArrayWrapper.column_stack`.
        group_by: Grouping passed to `ArrayWrapper.column_stack` (unless overridden
            by `wrapper_kwargs`).
        ffill_close: Whether to forward fill the stacked close and benchmark close.
        fbfill_close: Whether to forward and backward fill the stacked close and
            benchmark close. Takes precedence over `ffill_close`.
        **kwargs: Arguments of the new portfolio. Any argument provided here is used as-is
            instead of being derived by stacking.

    Raises:
        TypeError: If any object is not an instance of `Portfolio`.
        ValueError: If the objects differ in `cash_sharing`."""
    if not isinstance(cls_or_self, type):
        objs = (cls_or_self, *objs)
        cls = type(cls_or_self)
    else:
        cls = cls_or_self
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    for obj in objs:
        if not checks.is_instance_of(obj, Portfolio):
            raise TypeError("Each object to be merged must be an instance of Portfolio")
    # Arrays are read from portfolios with weights disabled; weights are stacked separately below
    _objs = list(map(lambda x: x.disable_weights(), objs))

    if "wrapper" not in kwargs:
        wrapper_kwargs = merge_dicts(dict(group_by=group_by), wrapper_kwargs)
        kwargs["wrapper"] = ArrayWrapper.column_stack(
            *[obj.wrapper for obj in _objs],
            **wrapper_kwargs,
        )
    for i in range(1, len(_objs)):
        if _objs[i].cash_sharing != _objs[0].cash_sharing:
            raise ValueError("Objects to be merged must have the same 'cash_sharing'")
    if "cash_sharing" not in kwargs:
        kwargs["cash_sharing"] = _objs[0].cash_sharing
    cs_group_by = None if kwargs["cash_sharing"] else False

    if "close" not in kwargs:
        new_close = kwargs["wrapper"].column_stack_arrs(
            *[obj.close for obj in _objs],
            group_by=False,
        )
        if fbfill_close:
            new_close = new_close.vbt.fbfill()
        elif ffill_close:
            new_close = new_close.vbt.ffill()
        kwargs["close"] = new_close

    # Optional price arrays are stacked only if every object provides them
    for price_name in ("open", "high", "low"):
        if price_name not in kwargs:
            if all(getattr(obj, "_" + price_name) is not None for obj in _objs):
                kwargs[price_name] = kwargs["wrapper"].column_stack_arrs(
                    *[getattr(obj, price_name) for obj in _objs],
                    group_by=False,
                    wrap=False,
                )

    if "order_records" not in kwargs:
        kwargs["order_records"] = Orders.column_stack_records_arrs(*[obj.orders for obj in _objs], **kwargs)
    if "log_records" not in kwargs:
        kwargs["log_records"] = Logs.column_stack_records_arrs(*[obj.logs for obj in _objs], **kwargs)

    if "init_cash" not in kwargs:
        # Stack only if at least one object has initial cash that is not an `InitCashMode` option
        stack_init_cash_objs = any(
            not checks.is_int(obj._init_cash) or obj._init_cash not in enums.InitCashMode for obj in _objs
        )
        if stack_init_cash_objs:
            kwargs["init_cash"] = to_1d_array(
                kwargs["wrapper"].concat_arrs(
                    *[obj.get_init_cash(group_by=cs_group_by) for obj in _objs],
                    group_by=cs_group_by,
                )
            )

    if "init_position" not in kwargs:
        if any((to_1d_array(obj.init_position) != 0).any() for obj in _objs):
            kwargs["init_position"] = to_1d_array(
                kwargs["wrapper"].concat_arrs(
                    *[obj.init_position for obj in _objs],
                    group_by=False,
                ),
            )
        else:
            kwargs["init_position"] = np.array([0.0])

    if "init_price" not in kwargs:
        if any(not np.isnan(to_1d_array(obj.init_price)).all() for obj in _objs):
            kwargs["init_price"] = to_1d_array(
                kwargs["wrapper"].concat_arrs(
                    *[obj.init_price for obj in _objs],
                    group_by=False,
                ),
            )
        else:
            kwargs["init_price"] = np.array([np.nan])

    if "cash_deposits" not in kwargs:
        # A single zero stands for "no deposits"; anything else requires stacking
        if any(obj._cash_deposits.size > 1 or obj._cash_deposits.item() != 0 for obj in _objs):
            kwargs["cash_deposits"] = kwargs["wrapper"].column_stack_arrs(
                *[obj.get_cash_deposits(group_by=cs_group_by) for obj in _objs],
                group_by=cs_group_by,
                reindex_kwargs=dict(fill_value=0),
                wrap=False,
            )
        else:
            kwargs["cash_deposits"] = np.array([[0.0]])

    if "cash_earnings" not in kwargs:
        # A single zero stands for "no earnings"; anything else requires stacking
        if any(obj._cash_earnings.size > 1 or obj._cash_earnings.item() != 0 for obj in _objs):
            kwargs["cash_earnings"] = kwargs["wrapper"].column_stack_arrs(
                *[obj.get_cash_earnings(group_by=False) for obj in _objs],
                group_by=False,
                reindex_kwargs=dict(fill_value=0),
                wrap=False,
            )
        else:
            kwargs["cash_earnings"] = np.array([[0.0]])

    if "call_seq" not in kwargs:
        if all(obj.config["call_seq"] is not None for obj in _objs):
            kwargs["call_seq"] = kwargs["wrapper"].column_stack_arrs(
                *[obj.call_seq for obj in _objs],
                group_by=False,
                reindex_kwargs=dict(fill_value=0),
                wrap=False,
            )

    if "bm_close" not in kwargs:
        if all(obj._bm_close is not None and not isinstance(obj._bm_close, bool) for obj in _objs):
            bm_stack_kwargs = dict(group_by=False)
            if not (fbfill_close or ffill_close):
                # No filling requested: pass on the raw stacked array
                bm_stack_kwargs["wrap"] = False
            # Otherwise keep the result wrapped, exactly like `close` above: the `.vbt`
            # accessor used for filling is not available on an unwrapped array
            new_bm_close = kwargs["wrapper"].column_stack_arrs(
                *[obj.bm_close for obj in _objs],
                **bm_stack_kwargs,
            )
            if fbfill_close:
                new_bm_close = new_bm_close.vbt.fbfill()
            elif ffill_close:
                new_bm_close = new_bm_close.vbt.ffill()
            kwargs["bm_close"] = new_bm_close

    if "in_outputs" not in kwargs:
        kwargs["in_outputs"] = cls.column_stack_in_outputs(*_objs, **kwargs)
    if "sim_start" not in kwargs:
        kwargs["sim_start"] = cls.column_stack_sim_start(kwargs["wrapper"], *_objs)
    if "sim_end" not in kwargs:
        kwargs["sim_end"] = cls.column_stack_sim_end(kwargs["wrapper"], *_objs)

    if "weights" not in kwargs:
        # Read weights from the original objects since `_objs` have them disabled;
        # objects without weights contribute NaN for each of their columns
        stack_weights_objs = False
        obj_weights = []
        for obj in objs:
            if obj.weights is not None:
                stack_weights_objs = True
                obj_weights.append(obj.weights)
            else:
                obj_weights.append([np.nan] * obj.wrapper.shape_2d[1])
        if stack_weights_objs:
            kwargs["weights"] = to_1d_array(
                kwargs["wrapper"].concat_arrs(
                    *obj_weights,
                    group_by=False,
                ),
            )

    kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    return cls(**kwargs)
tp.Optional[tp.ArrayLike] = None, high: tp.Optional[tp.ArrayLike] = None, low: tp.Optional[tp.ArrayLike] = None, log_records: tp.Optional[tp.RecordArray] = None, cash_sharing: bool = False, init_cash: tp.Union[str, tp.ArrayLike] = "auto", init_position: tp.ArrayLike = 0.0, init_price: tp.ArrayLike = np.nan, cash_deposits: tp.ArrayLike = 0.0, cash_deposits_as_input: tp.Optional[bool] = None, cash_earnings: tp.ArrayLike = 0.0, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, call_seq: tp.Optional[tp.Array2d] = None, in_outputs: tp.Optional[tp.NamedTuple] = None, use_in_outputs: tp.Optional[bool] = None, bm_close: tp.Optional[tp.ArrayLike] = None, fillna_close: tp.Optional[bool] = None, year_freq: tp.Optional[tp.FrequencyLike] = None, returns_acc_defaults: tp.KwargsLike = None, trades_type: tp.Optional[tp.Union[str, int]] = None, orders_cls: tp.Optional[type] = None, logs_cls: tp.Optional[type] = None, trades_cls: tp.Optional[type] = None, entry_trades_cls: tp.Optional[type] = None, exit_trades_cls: tp.Optional[type] = None, positions_cls: tp.Optional[type] = None, drawdowns_cls: tp.Optional[type] = None, weights: tp.Union[None, bool, tp.ArrayLike] = None, **kwargs, ) -> None: from vectorbtpro._settings import settings portfolio_cfg = settings["portfolio"] if cash_sharing: if wrapper.grouper.allow_enable or wrapper.grouper.allow_modify: wrapper = wrapper.replace(allow_enable=False, allow_modify=False) if isinstance(order_records, enums.SimulationOutput): sim_out = order_records order_records = sim_out.order_records log_records = sim_out.log_records cash_deposits = sim_out.cash_deposits cash_earnings = sim_out.cash_earnings sim_start = sim_out.sim_start sim_end = sim_out.sim_end call_seq = sim_out.call_seq in_outputs = sim_out.in_outputs close = to_2d_array(close) if open is not None: open = to_2d_array(open) if high is not None: high = to_2d_array(high) if low is not None: low = to_2d_array(low) if isinstance(init_cash, str): 
init_cash = map_enum_fields(init_cash, enums.InitCashMode) if not checks.is_int(init_cash) or init_cash not in enums.InitCashMode: init_cash = to_1d_array(init_cash) init_position = to_1d_array(init_position) init_price = to_1d_array(init_price) cash_deposits = to_2d_array(cash_deposits) cash_earnings = to_2d_array(cash_earnings) if cash_deposits_as_input is None: cash_deposits_as_input = portfolio_cfg["cash_deposits_as_input"] if bm_close is not None and not isinstance(bm_close, bool): bm_close = to_2d_array(bm_close) if log_records is None: log_records = np.array([], dtype=enums.log_dt) if use_in_outputs is None: use_in_outputs = portfolio_cfg["use_in_outputs"] if fillna_close is None: fillna_close = portfolio_cfg["fillna_close"] if weights is None: weights = portfolio_cfg["weights"] if trades_type is None: trades_type = portfolio_cfg["trades_type"] if isinstance(trades_type, str): trades_type = map_enum_fields(trades_type, enums.TradesType) Analyzable.__init__( self, wrapper, order_records=order_records, open=open, high=high, low=low, close=close, log_records=log_records, cash_sharing=cash_sharing, init_cash=init_cash, init_position=init_position, init_price=init_price, cash_deposits=cash_deposits, cash_deposits_as_input=cash_deposits_as_input, cash_earnings=cash_earnings, sim_start=sim_start, sim_end=sim_end, call_seq=call_seq, in_outputs=in_outputs, use_in_outputs=use_in_outputs, bm_close=bm_close, fillna_close=fillna_close, year_freq=year_freq, returns_acc_defaults=returns_acc_defaults, trades_type=trades_type, orders_cls=orders_cls, logs_cls=logs_cls, trades_cls=trades_cls, entry_trades_cls=entry_trades_cls, exit_trades_cls=exit_trades_cls, positions_cls=positions_cls, drawdowns_cls=drawdowns_cls, weights=weights, **kwargs, ) SimRangeMixin.__init__(self, sim_start=sim_start, sim_end=sim_end) self._open = open self._high = high self._low = low self._close = close self._order_records = order_records self._log_records = log_records self._cash_sharing = 
@classmethod
def parse_field_options(cls, field: str) -> tp.Kwargs:
    """Parse options encoded in the name of a field.

    The name is split by underscores and each part is checked against the known markers.

    Grouping markers:

    * '_cs': per group if grouped with cash sharing, otherwise per column
    * '_pcg': per group if grouped, otherwise per column
    * '_pg': per group
    * '_pc': per column

    Object type markers:

    * '_2d': element per timestamp and column or group (time series)
    * '_1d': element per column or group (reduced time series)
    * '_records': records

    Markers are recognized at any position in the name; if several markers of the same kind
    are present, the last one wins.

    Returns a dictionary with the keys `grouping` and `obj_type` (each only if a marker was found)
    and `field`, which is the name with all markers removed."""
    marker_options = {
        "1d": ("obj_type", "red_array"),
        "2d": ("obj_type", "array"),
        "records": ("obj_type", "records"),
        "pc": ("grouping", "columns"),
        "pg": ("grouping", "groups"),
        "pcg": ("grouping", "columns_or_groups"),
        "cs": ("grouping", "cash_sharing"),
    }
    parsed = {}
    name_parts = []
    for token in field.split("_"):
        if token in marker_options:
            option_name, option_value = marker_options[token]
            parsed[option_name] = option_value
        else:
            name_parts.append(token)
    parsed["field"] = "_".join(name_parts)
    return parsed
def matches_field_options(
    self,
    options: tp.Kwargs,
    obj_type: tp.Optional[str] = None,
    group_by_aware: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
) -> bool:
    """Return whether the options of a field satisfy the requirements.

    Two things are compared:

    * Object type: if both `options["obj_type"]` and `obj_type` are set, they must be equal.
    * Grouping: `options["grouping"]` must be compatible with the layout that is requested, which
    depends on whether the wrapper is grouped (under `group_by`), on `group_by_aware`, and on
    whether this portfolio shares cash:

        * not grouped: "columns", "columns_or_groups", and "cash_sharing" match
        * grouped and `group_by_aware`: "groups" and "columns_or_groups" match,
        "cash_sharing" matches only with cash sharing
        * grouped and not `group_by_aware`: "columns" matches,
        "cash_sharing" matches only without cash sharing

    An option that is missing from `options` is treated as matching."""
    declared_type = options.get("obj_type", None)
    if declared_type is not None and obj_type is not None and declared_type != obj_type:
        return False

    declared_grouping = options.get("grouping", None)
    if declared_grouping is None:
        return True

    if wrapper is None:
        wrapper = self.wrapper
    if not wrapper.grouper.is_grouped(group_by=group_by):
        # Everything that can be laid out per column matches
        return declared_grouping in ("columns", "columns_or_groups", "cash_sharing")
    if group_by_aware:
        # A per-group layout is requested
        if declared_grouping in ("groups", "columns_or_groups"):
            return True
        return bool(self.cash_sharing) and declared_grouping == "cash_sharing"
    # Grouped, but a per-column layout is requested
    if declared_grouping == "columns":
        return True
    return (not self.cash_sharing) and declared_grouping == "cash_sharing"
def wrap_obj(
    self,
    obj: tp.Any,
    obj_name: tp.Optional[str] = None,
    grouping: str = "columns_or_groups",
    obj_type: tp.Optional[str] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_func: tp.Optional[tp.Callable] = None,
    wrap_kwargs: tp.KwargsLike = None,
    force_wrapping: bool = False,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Any:
    """Wrap an object.

    `wrap_func` must take the portfolio, `obj`, all the arguments passed to this method, and `**kwargs`.
    If you don't need any of the arguments, make `wrap_func` accept them as `**kwargs`.

    Without `wrap_func`, the wrapping method is selected as follows:

    * If `obj_type` is set (and is not "records"), `grouping` decides between the grouped and
    the column-wise layout, while `obj_type` ("array" or "red_array") or, for any other
    object type, `obj.ndim` decides between regular and reduced wrapping.
    * Otherwise, or if nothing was selected above, a NumPy array that is not a record array
    is matched by its shape against the grouped and the ungrouped shape of the wrapper.
    * If the object still cannot be wrapped, raises `NotImplementedError` if `force_wrapping`
    is True, otherwise issues a warning (unless `silence_warnings` is True) and returns
    the object as-is.

    If the object is None or boolean, returns as-is."""
    if obj is None or isinstance(obj, bool):
        return obj
    if wrapper is None:
        wrapper = self.wrapper
    is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
    if wrap_func is not None:
        # A custom function takes over completely
        return wrap_func(
            self,
            obj,
            obj_name=obj_name,
            grouping=grouping,
            obj_type=obj_type,
            wrapper=wrapper,
            group_by=group_by,
            wrap_kwargs=wrap_kwargs,
            force_wrapping=force_wrapping,
            silence_warnings=silence_warnings,
            **kwargs,
        )

    # Reduced object (one element per group), `obj_name` is the default `name_or_index`
    def _wrap_reduced_grouped(obj):
        _wrap_kwargs = merge_dicts(dict(name_or_index=obj_name), wrap_kwargs)
        return wrapper.wrap_reduced(obj, group_by=group_by, **_wrap_kwargs)

    # Reduced object (one element per column), grouping is disabled
    def _wrap_reduced(obj):
        _wrap_kwargs = merge_dicts(dict(name_or_index=obj_name), wrap_kwargs)
        return wrapper.wrap_reduced(obj, group_by=False, **_wrap_kwargs)

    # Regular object using the (possibly overridden) grouping
    def _wrap_grouped(obj):
        return wrapper.wrap(obj, group_by=group_by, **resolve_dict(wrap_kwargs))

    # Regular object with grouping disabled
    def _wrap(obj):
        return wrapper.wrap(obj, group_by=False, **resolve_dict(wrap_kwargs))

    # NOTE(review): the indentation of this file was lost; this guard is assumed to enclose
    # all four grouping branches below - confirm against the upstream source
    if obj_type is not None and obj_type not in {"records"}:
        if grouping == "cash_sharing":
            # Per group only if grouped and cash is shared, per column otherwise
            if obj_type == "array":
                if is_grouped and self.cash_sharing:
                    return _wrap_grouped(obj)
                return _wrap(obj)
            if obj_type == "red_array":
                if is_grouped and self.cash_sharing:
                    return _wrap_reduced_grouped(obj)
                return _wrap_reduced(obj)
            if obj.ndim == 2:
                if is_grouped and self.cash_sharing:
                    return _wrap_grouped(obj)
                return _wrap(obj)
            if obj.ndim == 1:
                if is_grouped and self.cash_sharing:
                    return _wrap_reduced_grouped(obj)
                return _wrap_reduced(obj)
        if grouping == "columns_or_groups":
            # Per group if grouped, per column otherwise
            if obj_type == "array":
                if is_grouped:
                    return _wrap_grouped(obj)
                return _wrap(obj)
            if obj_type == "red_array":
                if is_grouped:
                    return _wrap_reduced_grouped(obj)
                return _wrap_reduced(obj)
            if obj.ndim == 2:
                if is_grouped:
                    return _wrap_grouped(obj)
                return _wrap(obj)
            if obj.ndim == 1:
                if is_grouped:
                    return _wrap_reduced_grouped(obj)
                return _wrap_reduced(obj)
        if grouping == "groups":
            # Always per group
            if obj_type == "array":
                return _wrap_grouped(obj)
            if obj_type == "red_array":
                return _wrap_reduced_grouped(obj)
            if obj.ndim == 2:
                return _wrap_grouped(obj)
            if obj.ndim == 1:
                return _wrap_reduced_grouped(obj)
        if grouping == "columns":
            # Always per column
            if obj_type == "array":
                return _wrap(obj)
            if obj_type == "red_array":
                return _wrap_reduced(obj)
            if obj.ndim == 2:
                return _wrap(obj)
            if obj.ndim == 1:
                return _wrap_reduced(obj)
    if obj_type not in {"records"}:
        # Fallback: infer the layout from the shape of a regular NumPy array
        if checks.is_np_array(obj) and not checks.is_record_array(obj):
            if is_grouped:
                if obj_type is not None and obj_type == "array":
                    return _wrap_grouped(obj)
                if obj_type is not None and obj_type == "red_array":
                    return _wrap_reduced_grouped(obj)
                if to_2d_shape(obj.shape) == wrapper.get_shape_2d():
                    return _wrap_grouped(obj)
                if obj.shape == (wrapper.get_shape_2d()[1],):
                    return _wrap_reduced_grouped(obj)
            # Reached when not grouped, or when the grouped shape didn't match
            if obj_type is not None and obj_type == "array":
                return _wrap(obj)
            if obj_type is not None and obj_type == "red_array":
                return _wrap_reduced(obj)
            if to_2d_shape(obj.shape) == wrapper.shape_2d:
                return _wrap(obj)
            if obj.shape == (wrapper.shape_2d[1],):
                return _wrap_reduced(obj)
    if force_wrapping:
        raise NotImplementedError(f"Cannot wrap object '{obj_name}'")
    if not silence_warnings:
        warn(f"Cannot figure out how to wrap object '{obj_name}'")
    return obj
def get_in_output(
    self,
    field: str,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> tp.Union[None, bool, tp.AnyArray]:
    """Find and wrap an in-output object matching the field.

    If `field` is an attribute of this `Portfolio` class (that is, it's in `cls_dir`), the `options`
    of that attribute provide the requirements: the object type (defaults to "array"), whether it's
    aware of grouping, wrapping arguments, and additional field aliases.

    Each field in `Portfolio.in_outputs` gets its options by parsing its name with
    `Portfolio.parse_field_options` and merging in its entry in `Portfolio.in_output_config`.
    A field is selected if either

    * `field` is not an attribute and the names are identical, or
    * the name of the field (raw or with markers removed) is among the aliases and its options
    satisfy the requirements as per `Portfolio.matches_field_options`.

    Raises `ValueError` if no in-outputs are attached or if more than one field is selected,
    and `AttributeError` if none is selected.

    Finally, the selected object is wrapped using `Portfolio.wrap_obj`, with options of the field
    taking precedence over options of the attribute, and `**kwargs` over both."""
    if self.in_outputs is None:
        raise ValueError("No in-outputs attached")

    # Requirements come from the attribute if there is one, otherwise nothing is required
    is_known_attr = field in self.cls_dir
    if is_known_attr:
        attr_options = getattr(getattr(type(self), field), "options", {})
        obj_type = attr_options.get("obj_type", "array")
    else:
        attr_options = {}
        obj_type = None
    group_by_aware = attr_options.get("group_by_aware", True)
    wrap_func = attr_options.get("wrap_func", None)
    wrap_kwargs = attr_options.get("wrap_kwargs", None)
    force_wrapping = attr_options.get("force_wrapping", False)
    silence_warnings = attr_options.get("silence_warnings", False)
    aliases = {field}
    if is_known_attr:
        extra_aliases = attr_options.get("field_aliases", None)
        if extra_aliases is not None:
            aliases.update(extra_aliases)

    # Look for exactly one compatible field
    match_name = None
    match_options = None
    for candidate in set(self.in_outputs._fields):
        candidate_options = merge_dicts(
            self.parse_field_options(candidate),
            self.in_output_config.get(candidate, None),
        )
        is_exact = not is_known_attr and field == candidate
        if not is_exact:
            if not (candidate in aliases or candidate_options.get("field", candidate) in aliases):
                continue
            if not self.matches_field_options(
                candidate_options,
                obj_type=obj_type,
                group_by_aware=group_by_aware,
                wrapper=wrapper,
                group_by=group_by,
            ):
                continue
        if match_name is not None:
            raise ValueError(f"Multiple fields for '{field}' found in in_outputs")
        match_name = candidate
        match_options = candidate_options
    if match_name is None:
        raise AttributeError(f"No compatible field for '{field}' found in in_outputs")

    obj = getattr(self.in_outputs, match_name)
    if is_known_attr and checks.is_np_array(obj) and obj.shape == (0, 0):
        # An array of shape (0, 0) counts as missing (the original comment says "for returns")
        return None

    wrap_options = merge_dicts(
        dict(
            grouping=match_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
            obj_type=match_options.get("obj_type", obj_type),
            wrap_func=match_options.get("wrap_func", wrap_func),
            wrap_kwargs=match_options.get("wrap_kwargs", wrap_kwargs),
            force_wrapping=match_options.get("force_wrapping", force_wrapping),
            silence_warnings=match_options.get("silence_warnings", silence_warnings),
        ),
        kwargs,
    )
    clean_name = match_options.get("field", match_name)
    return self.wrap_obj(obj, clean_name, wrapper=wrapper, group_by=group_by, **wrap_options)
def index_obj(
    self,
    obj: tp.Any,
    wrapper_meta: dict,
    obj_name: tp.Optional[str] = None,
    grouping: str = "columns_or_groups",
    obj_type: tp.Optional[str] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    indexing_func: tp.Optional[tp.Callable] = None,
    force_indexing: bool = False,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Any:
    """Perform indexing on an object.

    `indexing_func` must take the portfolio, `obj`, `wrapper_meta`, all the other arguments passed
    to this method, and `**kwargs`. If you don't need any of the arguments, make `indexing_func`
    accept them as `**kwargs`.

    Without `indexing_func`, the indices are read from `wrapper_meta` ("row_idxs", "col_idxs",
    and "group_idxs") and applied as follows:

    * Records (`obj_type="records"`) are indexed using `Records`.
    * Otherwise, `grouping` decides between group and column indices, while `obj_type`
    ("array" or "red_array") or, if it's not one of them, `obj.ndim` decides between
    two-dimensional (rows and columns) and one-dimensional (columns only) indexing.
    * If nothing was selected above, a NumPy array is matched by its shape against
    the grouped and the ungrouped shape of the wrapper.
    * If the object still cannot be indexed, raises `NotImplementedError` if `force_indexing`
    is True, otherwise issues a warning (unless `silence_warnings` is True) and returns
    the object as-is.

    If the object is None, boolean, or empty, returns as-is."""
    if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
        return obj
    if wrapper is None:
        wrapper = self.wrapper
    if indexing_func is not None:
        # A custom function takes over completely
        return indexing_func(
            self,
            obj,
            wrapper_meta,
            obj_name=obj_name,
            grouping=grouping,
            obj_type=obj_type,
            wrapper=wrapper,
            group_by=group_by,
            force_indexing=force_indexing,
            silence_warnings=silence_warnings,
            **kwargs,
        )

    # One element per group
    def _index_1d_by_group(obj: tp.ArrayLike) -> tp.ArrayLike:
        return to_1d_array(obj)[wrapper_meta["group_idxs"]]

    # One element per column
    def _index_1d_by_col(obj: tp.ArrayLike) -> tp.ArrayLike:
        return to_1d_array(obj)[wrapper_meta["col_idxs"]]

    # Rows are selected first, then groups
    def _index_2d_by_group(obj: tp.ArrayLike) -> tp.ArrayLike:
        return to_2d_array(obj)[wrapper_meta["row_idxs"], :][:, wrapper_meta["group_idxs"]]

    # Rows are selected first, then columns
    def _index_2d_by_col(obj: tp.ArrayLike) -> tp.ArrayLike:
        return to_2d_array(obj)[wrapper_meta["row_idxs"], :][:, wrapper_meta["col_idxs"]]

    # Records are wrapped with `Records`, indexed, and unwrapped again
    def _index_records(obj: tp.RecordArray) -> tp.RecordArray:
        records = Records(wrapper, obj)
        records_meta = records.indexing_func_meta(wrapper_meta=wrapper_meta)
        return records.indexing_func(records_meta=records_meta).values

    is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
    if obj_type is not None and obj_type == "records":
        return _index_records(obj)
    if grouping == "cash_sharing":
        # By group only if grouped and cash is shared, by column otherwise
        if obj_type is not None and obj_type == "array":
            if is_grouped and self.cash_sharing:
                return _index_2d_by_group(obj)
            return _index_2d_by_col(obj)
        if obj_type is not None and obj_type == "red_array":
            if is_grouped and self.cash_sharing:
                return _index_1d_by_group(obj)
            return _index_1d_by_col(obj)
        if obj.ndim == 2:
            if is_grouped and self.cash_sharing:
                return _index_2d_by_group(obj)
            return _index_2d_by_col(obj)
        if obj.ndim == 1:
            if is_grouped and self.cash_sharing:
                return _index_1d_by_group(obj)
            return _index_1d_by_col(obj)
    if grouping == "columns_or_groups":
        # By group if grouped, by column otherwise
        if obj_type is not None and obj_type == "array":
            if is_grouped:
                return _index_2d_by_group(obj)
            return _index_2d_by_col(obj)
        if obj_type is not None and obj_type == "red_array":
            if is_grouped:
                return _index_1d_by_group(obj)
            return _index_1d_by_col(obj)
        if obj.ndim == 2:
            if is_grouped:
                return _index_2d_by_group(obj)
            return _index_2d_by_col(obj)
        if obj.ndim == 1:
            if is_grouped:
                return _index_1d_by_group(obj)
            return _index_1d_by_col(obj)
    if grouping == "groups":
        # Always by group
        if obj_type is not None and obj_type == "array":
            return _index_2d_by_group(obj)
        if obj_type is not None and obj_type == "red_array":
            return _index_1d_by_group(obj)
        if obj.ndim == 2:
            return _index_2d_by_group(obj)
        if obj.ndim == 1:
            return _index_1d_by_group(obj)
    if grouping == "columns":
        # Always by column
        if obj_type is not None and obj_type == "array":
            return _index_2d_by_col(obj)
        if obj_type is not None and obj_type == "red_array":
            return _index_1d_by_col(obj)
        if obj.ndim == 2:
            return _index_2d_by_col(obj)
        if obj.ndim == 1:
            return _index_1d_by_col(obj)
    # Fallback: infer the layout from the shape of a NumPy array
    if checks.is_np_array(obj):
        if is_grouped:
            if obj_type is not None and obj_type == "array":
                return _index_2d_by_group(obj)
            if obj_type is not None and obj_type == "red_array":
                return _index_1d_by_group(obj)
            if to_2d_shape(obj.shape) == wrapper.get_shape_2d():
                return _index_2d_by_group(obj)
            if obj.shape == (wrapper.get_shape_2d()[1],):
                return _index_1d_by_group(obj)
        # Reached when not grouped, or when the grouped shape didn't match
        if obj_type is not None and obj_type == "array":
            return _index_2d_by_col(obj)
        if obj_type is not None and obj_type == "red_array":
            return _index_1d_by_col(obj)
        if to_2d_shape(obj.shape) == wrapper.shape_2d:
            return _index_2d_by_col(obj)
        if obj.shape == (wrapper.shape_2d[1],):
            return _index_1d_by_col(obj)
    if force_indexing:
        raise NotImplementedError(f"Cannot index object '{obj_name}'")
    if not silence_warnings:
        warn(f"Cannot figure out how to index object '{obj_name}'")
    return obj
def in_outputs_indexing_func(self, wrapper_meta: dict, **kwargs) -> tp.Optional[tp.NamedTuple]:
    """Perform indexing on `Portfolio.in_outputs`.

    Each field gets its options by parsing its name with `Portfolio.parse_field_options` and
    merging in its entry in `Portfolio.in_output_config`. If the name of the field (with markers
    removed) is an attribute of this `Portfolio` class, the `options` of that attribute provide
    the defaults for the object type (defaults to "array"), grouping, and indexing arguments.
    Options of the field take precedence over options of the attribute, and `**kwargs` over both.

    Each object is then indexed using `Portfolio.index_obj`.

    Returns a named tuple of the same type, or None if there are no in-outputs."""
    if self.in_outputs is None:
        return None

    def _index_field(raw_name, obj):
        field_options = merge_dicts(
            self.parse_field_options(raw_name),
            self.in_output_config.get(raw_name, None),
        )
        if field_options.get("field", raw_name) in self.cls_dir:
            attr_options = getattr(getattr(type(self), field_options["field"]), "options", {})
            default_obj_type = attr_options.get("obj_type", "array")
        else:
            # Unknown attribute: nothing to read defaults from
            attr_options = {}
            default_obj_type = None
        default_grouping = "columns_or_groups" if attr_options.get("group_by_aware", True) else "columns"
        index_options = merge_dicts(
            dict(
                grouping=field_options.get("grouping", default_grouping),
                obj_name=field_options.get("field", raw_name),
                obj_type=field_options.get("obj_type", default_obj_type),
                indexing_func=field_options.get("indexing_func", attr_options.get("indexing_func", None)),
                force_indexing=field_options.get("force_indexing", attr_options.get("force_indexing", False)),
                silence_warnings=field_options.get(
                    "silence_warnings",
                    attr_options.get("silence_warnings", False),
                ),
            ),
            kwargs,
        )
        return self.index_obj(obj, wrapper_meta, **index_options)

    indexed = {raw_name: _index_field(raw_name, obj) for raw_name, obj in self.in_outputs._asdict().items()}
    return type(self.in_outputs)(**indexed)
def indexing_func(
    self: PortfolioT,
    *args,
    in_output_kwargs: tp.KwargsLike = None,
    wrapper_meta: tp.DictLike = None,
    **kwargs,
) -> PortfolioT:
    """Perform indexing on `Portfolio`.

    If `wrapper_meta` is not provided, it's built from `*args` and `**kwargs` using
    `ArrayWrapper.indexing_func_meta` of the wrapper of this portfolio.

    Prices, cash deposits, cash earnings, the call sequence, and the benchmark are selected using
    `ArrayWrapper.select_from_flex_array`, order and log records using their `indexing_func_meta`.

    If rows at the beginning are cut off (`row_idxs.start > 0`), the initial state is
    taken from the row right before the new first row: cash becomes the initial cash, assets
    become the initial position, and the latest forward-filled close becomes the initial price.

    In-outputs are indexed using `Portfolio.in_outputs_indexing_func`, to which
    `in_output_kwargs` are passed.

    Returns a new instance created with `Portfolio.replace`."""
    # Computations are run on `_self`, while weights and `replace` below use the original `self`
    # (presumably `disable_weights` returns a copy with weights turned off - confirm)
    _self = self.disable_weights()

    if wrapper_meta is None:
        wrapper_meta = _self.wrapper.indexing_func_meta(
            *args,
            column_only_select=_self.column_only_select,
            range_only_select=_self.range_only_select,
            group_select=_self.group_select,
            **kwargs,
        )
    new_wrapper = wrapper_meta["new_wrapper"]
    row_idxs = wrapper_meta["row_idxs"]
    rows_changed = wrapper_meta["rows_changed"]
    col_idxs = wrapper_meta["col_idxs"]
    columns_changed = wrapper_meta["columns_changed"]
    group_idxs = wrapper_meta["group_idxs"]

    # Prices (open, high, and low are optional)
    new_close = ArrayWrapper.select_from_flex_array(
        _self._close,
        row_idxs=row_idxs,
        col_idxs=col_idxs,
        rows_changed=rows_changed,
        columns_changed=columns_changed,
    )
    if _self._open is not None:
        new_open = ArrayWrapper.select_from_flex_array(
            _self._open,
            row_idxs=row_idxs,
            col_idxs=col_idxs,
            rows_changed=rows_changed,
            columns_changed=columns_changed,
        )
    else:
        new_open = _self._open
    if _self._high is not None:
        new_high = ArrayWrapper.select_from_flex_array(
            _self._high,
            row_idxs=row_idxs,
            col_idxs=col_idxs,
            rows_changed=rows_changed,
            columns_changed=columns_changed,
        )
    else:
        new_high = _self._high
    if _self._low is not None:
        new_low = ArrayWrapper.select_from_flex_array(
            _self._low,
            row_idxs=row_idxs,
            col_idxs=col_idxs,
            rows_changed=rows_changed,
            columns_changed=columns_changed,
        )
    else:
        new_low = _self._low

    # Records
    new_order_records = _self.orders.indexing_func_meta(wrapper_meta=wrapper_meta)["new_records_arr"]
    new_log_records = _self.logs.indexing_func_meta(wrapper_meta=wrapper_meta)["new_records_arr"]

    # Initial cash: an integer (presumably an `InitCashMode` member) is passed through untouched
    # NOTE(review): the indentation of this file was lost; both nested conditions are assumed
    # to belong to the non-integer branch - confirm against the upstream source
    new_init_cash = _self._init_cash
    if not checks.is_int(new_init_cash):
        new_init_cash = to_1d_array(new_init_cash)
        if rows_changed and row_idxs.start > 0:
            # Cash per column is needed if grouped without cash sharing
            if _self.wrapper.grouper.is_grouped() and not _self.cash_sharing:
                cash = _self.get_cash(group_by=False)
            else:
                cash = _self.cash
            new_init_cash = to_1d_array(cash.iloc[row_idxs.start - 1])
        if columns_changed and new_init_cash.shape[0] > 1:
            # With cash sharing there is one element per group
            if _self.cash_sharing:
                new_init_cash = new_init_cash[group_idxs]
            else:
                new_init_cash = new_init_cash[col_idxs]

    # Initial position: assets held right before the new first row
    new_init_position = to_1d_array(_self._init_position)
    if rows_changed and row_idxs.start > 0:
        new_init_position = to_1d_array(_self.assets.iloc[row_idxs.start - 1])
    if columns_changed and new_init_position.shape[0] > 1:
        new_init_position = new_init_position[col_idxs]

    # Initial price: latest forward-filled close before the new first row
    new_init_price = to_1d_array(_self._init_price)
    if rows_changed and row_idxs.start > 0:
        new_init_price = to_1d_array(_self.close.iloc[: row_idxs.start].ffill().iloc[-1])
    if columns_changed and new_init_price.shape[0] > 1:
        new_init_price = new_init_price[col_idxs]

    # Cash deposits are selected by group with cash sharing, cash earnings always by column
    new_cash_deposits = ArrayWrapper.select_from_flex_array(
        _self._cash_deposits,
        row_idxs=row_idxs,
        col_idxs=group_idxs if _self.cash_sharing else col_idxs,
        rows_changed=rows_changed,
        columns_changed=columns_changed,
    )
    new_cash_earnings = ArrayWrapper.select_from_flex_array(
        _self._cash_earnings,
        row_idxs=row_idxs,
        col_idxs=col_idxs,
        rows_changed=rows_changed,
        columns_changed=columns_changed,
    )
    if _self._call_seq is not None:
        new_call_seq = ArrayWrapper.select_from_flex_array(
            _self._call_seq,
            row_idxs=row_idxs,
            col_idxs=col_idxs,
            rows_changed=rows_changed,
            columns_changed=columns_changed,
        )
    else:
        new_call_seq = None
    # The benchmark can also be None or a boolean, in which case it's passed through
    if _self._bm_close is not None and not isinstance(_self._bm_close, bool):
        new_bm_close = ArrayWrapper.select_from_flex_array(
            _self._bm_close,
            row_idxs=row_idxs,
            col_idxs=col_idxs,
            rows_changed=rows_changed,
            columns_changed=columns_changed,
        )
    else:
        new_bm_close = _self._bm_close
    new_in_outputs = _self.in_outputs_indexing_func(wrapper_meta, **resolve_dict(in_output_kwargs))
    new_sim_start = _self.sim_start_indexing_func(wrapper_meta)
    new_sim_end = _self.sim_end_indexing_func(wrapper_meta)

    # Weights are read from the original instance since they are disabled in `_self`
    if self.weights is not None:
        new_weights = to_1d_array(self.weights)
        if columns_changed and new_weights.shape[0] > 1:
            new_weights = new_weights[col_idxs]
    else:
        new_weights = self._weights

    return self.replace(
        wrapper=new_wrapper,
        order_records=new_order_records,
        open=new_open,
        high=new_high,
        low=new_low,
        close=new_close,
        log_records=new_log_records,
        init_cash=new_init_cash,
        init_position=new_init_position,
        init_price=new_init_price,
        cash_deposits=new_cash_deposits,
        cash_earnings=new_cash_earnings,
        call_seq=new_call_seq,
        in_outputs=new_in_outputs,
        bm_close=new_bm_close,
        sim_start=new_sim_start,
        sim_end=new_sim_end,
        weights=new_weights,
    )
def resample_obj(
    self,
    obj: tp.Any,
    resampler: tp.Union[Resampler, tp.PandasResampler],
    obj_name: tp.Optional[str] = None,
    obj_type: tp.Optional[str] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    resample_func: tp.Union[None, str, tp.Callable] = None,
    resample_kwargs: tp.KwargsLike = None,
    force_resampling: bool = False,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Any:
    """Resample an object.

    `resample_func` must take the portfolio, `obj`, `resampler`, all the arguments passed to this method,
    and `**kwargs`. If you don't need any of the arguments, make `resample_func` accept them as `**kwargs`.
    If `resample_func` is a string, will use it as `reduce_func_nb` in
    `vectorbtpro.generic.accessors.GenericAccessor.resample_apply`. Default is 'last'.

    With a string `resample_func`:

    * A reduced array (`obj_type="red_array"`) is returned as-is.
    * If `obj_type` is None or "array", a NumPy array is matched by its shape against the grouped
    (if grouped) and then the ungrouped shape of the wrapper: a time series is resampled, while
    an array with one element per column or group is returned as-is.
    * If the object cannot be resampled, raises `NotImplementedError` if `force_resampling`
    is True, otherwise issues a warning (unless `silence_warnings` is True) and returns
    the object as-is.

    If the object is None, boolean, or empty, returns as-is."""
    is_flag_or_missing = obj is None or isinstance(obj, bool)
    if is_flag_or_missing or (checks.is_np_array(obj) and obj.size == 0):
        return obj
    if wrapper is None:
        wrapper = self.wrapper
    reducer = "last" if resample_func is None else resample_func
    if not isinstance(reducer, str):
        # A custom function takes over completely
        return reducer(
            self,
            obj,
            resampler,
            obj_name=obj_name,
            obj_type=obj_type,
            wrapper=wrapper,
            group_by=group_by,
            resample_kwargs=resample_kwargs,
            force_resampling=force_resampling,
            silence_warnings=silence_warnings,
            **kwargs,
        )

    def _apply_reducer(arr: tp.Array) -> tp.Array:
        # Attach the index of the wrapper, resample, and go back to a raw array
        frame = ArrayWrapper.from_obj(arr, index=wrapper.index).wrap(arr)
        resampled = frame.vbt.resample_apply(resampler, reducer, **resolve_dict(resample_kwargs))
        return resampled.values

    if obj_type == "red_array":
        return obj
    if obj_type is None or obj_type == "array":
        is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
        if checks.is_np_array(obj):
            # Reference shapes are fetched lazily: grouped first (if grouped), then ungrouped
            shape_getters = [wrapper.get_shape_2d] if is_grouped else []
            shape_getters.append(lambda: wrapper.shape_2d)
            for get_shape in shape_getters:
                ref_shape = get_shape()
                if to_2d_shape(obj.shape) == ref_shape:
                    return _apply_reducer(obj)
                if obj.shape == (ref_shape[1],):
                    return obj

    if force_resampling:
        raise NotImplementedError(f"Cannot resample object '{obj_name}'")
    if not silence_warnings:
        warn(f"Cannot figure out how to resample object '{obj_name}'")
    return obj
def resample_in_outputs(
    self,
    resampler: tp.Union[Resampler, tp.PandasResampler],
    **kwargs,
) -> tp.Optional[tp.NamedTuple]:
    """Resample `Portfolio.in_outputs`.

    Each field gets its options by parsing its name with `Portfolio.parse_field_options` and
    merging in its entry in `Portfolio.in_output_config`. If the name of the field (with markers
    removed) is an attribute of this `Portfolio` class, the `options` of that attribute provide
    the defaults for the object type (defaults to "array") and resampling arguments.
    Options of the field take precedence over options of the attribute, and `**kwargs` over both.

    Each object is then resampled using `Portfolio.resample_obj`.

    Returns a named tuple of the same type, or None if there are no in-outputs."""
    if self.in_outputs is None:
        return None

    def _resample_field(raw_name, obj):
        field_options = merge_dicts(
            self.parse_field_options(raw_name),
            self.in_output_config.get(raw_name, None),
        )
        if field_options.get("field", raw_name) in self.cls_dir:
            attr_options = getattr(getattr(type(self), field_options["field"]), "options", {})
            default_obj_type = attr_options.get("obj_type", "array")
        else:
            # Unknown attribute: nothing to read defaults from
            attr_options = {}
            default_obj_type = None
        resample_options = merge_dicts(
            dict(
                obj_name=field_options.get("field", raw_name),
                obj_type=field_options.get("obj_type", default_obj_type),
                resample_func=field_options.get("resample_func", attr_options.get("resample_func", None)),
                resample_kwargs=field_options.get("resample_kwargs", attr_options.get("resample_kwargs", None)),
                force_resampling=field_options.get(
                    "force_resampling",
                    attr_options.get("force_resampling", False),
                ),
                silence_warnings=field_options.get(
                    "silence_warnings",
                    attr_options.get("silence_warnings", False),
                ),
            ),
            kwargs,
        )
        return self.resample_obj(obj, resampler, **resample_options)

    resampled = {raw_name: _resample_field(raw_name, obj) for raw_name, obj in self.in_outputs._asdict().items()}
    return type(self.in_outputs)(**resampled)
def resample(
    self: PortfolioT,
    *args,
    ffill_close: bool = False,
    fbfill_close: bool = False,
    in_output_kwargs: tp.KwargsLike = None,
    wrapper_meta: tp.DictLike = None,
    **kwargs,
) -> PortfolioT:
    """Resample the `Portfolio` instance.

    If `wrapper_meta` is not provided, it's built from `*args` and `**kwargs` using
    `ArrayWrapper.resample_meta` of the wrapper of this portfolio.

    Open is reduced with "first", high with "max", low with "min", and close (as well as the
    benchmark close) with "last". After resampling, the close and the benchmark close are forward-
    and backward-filled if `fbfill_close` is True, or else forward-filled if `ffill_close` is True.
    Cash deposits and cash earnings are summed. In-outputs are resampled using
    `Portfolio.resample_in_outputs`, to which `in_output_kwargs` are passed.

    Raises `ValueError` if there is a call sequence, since it cannot be resampled.

    !!! warning
        Downsampling is associated with information loss:

        * Cash deposits and earnings are assumed to be added/removed at the beginning of each time step.
            Imagine depositing $100 and using them up in the same bar, and then depositing another $100
            and using them up. Downsampling both bars into a single bar will aggregate cash deposits
            and earnings, and put both of them at the beginning of the new bar, even though the second
            deposit was added later in time.
        * Market/benchmark returns are computed by applying the initial value on the close price
            of the first bar and by tracking the price change to simulate holding. Moving the close
            price of the first bar further into the future will affect this computation and almost
            certainly produce a different market value and returns. To mitigate this, make sure
            to downsample to an index with the first bar containing only the first bar from the
            origin timeframe."""
    _self = self.disable_weights()

    if _self._call_seq is not None:
        raise ValueError("Cannot resample call_seq")
    if wrapper_meta is None:
        wrapper_meta = _self.wrapper.resample_meta(*args, **kwargs)
    resampler = wrapper_meta["resampler"]
    target_wrapper = wrapper_meta["new_wrapper"]

    def _fill_gaps(frame):
        # `fbfill_close` takes precedence over `ffill_close`
        if fbfill_close:
            return frame.vbt.fbfill()
        if ffill_close:
            return frame.vbt.ffill()
        return frame

    def _resample_optional_price(name, reducer):
        # A price that wasn't provided stays None
        if getattr(_self, "_" + name) is None:
            return None
        return getattr(_self, name).vbt.resample_apply(resampler, reducer).values

    def _has_cash_flows(raw):
        # A single zero is the default and means there is nothing to aggregate
        return raw.size > 1 or raw.item() != 0

    def _sum_cash_flows(frame):
        summed = frame.vbt.resample_apply(resampler, generic_nb.sum_reduce_nb)
        return summed.fillna(0.0).values

    # Prices
    close_arr = _fill_gaps(_self.close.vbt.resample_apply(resampler, "last")).values
    open_arr = _resample_optional_price("open", "first")
    high_arr = _resample_optional_price("high", "max")
    low_arr = _resample_optional_price("low", "min")

    # Records
    order_records_arr = _self.orders.resample_records_arr(resampler)
    log_records_arr = _self.logs.resample_records_arr(resampler)

    # Cash flows: deposits follow the cash-sharing layout, earnings are always per column
    if _has_cash_flows(_self._cash_deposits):
        deposits_group_by = None if _self.cash_sharing else False
        cash_deposits_arr = _sum_cash_flows(_self.get_cash_deposits(group_by=deposits_group_by))
    else:
        cash_deposits_arr = _self._cash_deposits
    if _has_cash_flows(_self._cash_earnings):
        cash_earnings_arr = _sum_cash_flows(_self.get_cash_earnings(group_by=False))
    else:
        cash_earnings_arr = _self._cash_earnings

    # Benchmark: None and booleans are passed through
    if _self._bm_close is None or isinstance(_self._bm_close, bool):
        bm_close_arr = _self._bm_close
    else:
        bm_close_arr = _fill_gaps(_self.bm_close.vbt.resample_apply(resampler, "last")).values

    # In-outputs and simulation range
    if _self._in_outputs is None:
        resampled_in_outputs = None
    else:
        resampled_in_outputs = _self.resample_in_outputs(resampler, **resolve_dict(in_output_kwargs))
    resampled_sim_start = _self.resample_sim_start(target_wrapper)
    resampled_sim_end = _self.resample_sim_end(target_wrapper)

    return self.replace(
        wrapper=target_wrapper,
        order_records=order_records_arr,
        open=open_arr,
        high=high_arr,
        low=low_arr,
        close=close_arr,
        log_records=log_records_arr,
        cash_deposits=cash_deposits_arr,
        cash_earnings=cash_earnings_arr,
        in_outputs=resampled_in_outputs,
        bm_close=bm_close_arr,
        sim_start=resampled_sim_start,
        sim_end=resampled_sim_end,
    )
_self._cash_deposits if _self._cash_earnings.size > 1 or _self._cash_earnings.item() != 0: new_cash_earnings = _self.get_cash_earnings(group_by=False) new_cash_earnings = new_cash_earnings.vbt.resample_apply(resampler, generic_nb.sum_reduce_nb) new_cash_earnings = new_cash_earnings.fillna(0.0) new_cash_earnings = new_cash_earnings.values else: new_cash_earnings = _self._cash_earnings if _self._bm_close is not None and not isinstance(_self._bm_close, bool): new_bm_close = _self.bm_close.vbt.resample_apply(resampler, "last") if fbfill_close: new_bm_close = new_bm_close.vbt.fbfill() elif ffill_close: new_bm_close = new_bm_close.vbt.ffill() new_bm_close = new_bm_close.values else: new_bm_close = _self._bm_close if _self._in_outputs is not None: new_in_outputs = _self.resample_in_outputs(resampler, **resolve_dict(in_output_kwargs)) else: new_in_outputs = None new_sim_start = _self.resample_sim_start(new_wrapper) new_sim_end = _self.resample_sim_end(new_wrapper) return self.replace( wrapper=new_wrapper, order_records=new_order_records, open=new_open, high=new_high, low=new_low, close=new_close, log_records=new_log_records, cash_deposits=new_cash_deposits, cash_earnings=new_cash_earnings, in_outputs=new_in_outputs, bm_close=new_bm_close, sim_start=new_sim_start, sim_end=new_sim_end, ) # ############# Class methods ############# # @classmethod def from_orders( cls: tp.Type[PortfolioT], close: tp.Union[tp.ArrayLike, OHLCDataMixin, FOPreparer, PFPrepResult], size: tp.Optional[tp.ArrayLike] = None, size_type: tp.Optional[tp.ArrayLike] = None, direction: tp.Optional[tp.ArrayLike] = None, price: tp.Optional[tp.ArrayLike] = None, fees: tp.Optional[tp.ArrayLike] = None, fixed_fees: tp.Optional[tp.ArrayLike] = None, slippage: tp.Optional[tp.ArrayLike] = None, min_size: tp.Optional[tp.ArrayLike] = None, max_size: tp.Optional[tp.ArrayLike] = None, size_granularity: tp.Optional[tp.ArrayLike] = None, leverage: tp.Optional[tp.ArrayLike] = None, leverage_mode: tp.Optional[tp.ArrayLike] 
= None,
        reject_prob: tp.Optional[tp.ArrayLike] = None,
        price_area_vio_mode: tp.Optional[tp.ArrayLike] = None,
        allow_partial: tp.Optional[tp.ArrayLike] = None,
        raise_reject: tp.Optional[tp.ArrayLike] = None,
        log: tp.Optional[tp.ArrayLike] = None,
        val_price: tp.Optional[tp.ArrayLike] = None,
        from_ago: tp.Optional[tp.ArrayLike] = None,
        open: tp.Optional[tp.ArrayLike] = None,
        high: tp.Optional[tp.ArrayLike] = None,
        low: tp.Optional[tp.ArrayLike] = None,
        init_cash: tp.Optional[tp.ArrayLike] = None,
        init_position: tp.Optional[tp.ArrayLike] = None,
        init_price: tp.Optional[tp.ArrayLike] = None,
        cash_deposits: tp.Optional[tp.ArrayLike] = None,
        cash_earnings: tp.Optional[tp.ArrayLike] = None,
        cash_dividends: tp.Optional[tp.ArrayLike] = None,
        cash_sharing: tp.Optional[bool] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        call_seq: tp.Optional[tp.ArrayLike] = None,
        attach_call_seq: tp.Optional[bool] = None,
        ffill_val_price: tp.Optional[bool] = None,
        update_value: tp.Optional[bool] = None,
        save_state: tp.Optional[bool] = None,
        save_value: tp.Optional[bool] = None,
        save_returns: tp.Optional[bool] = None,
        skip_empty: tp.Optional[bool] = None,
        max_order_records: tp.Optional[int] = None,
        max_log_records: tp.Optional[int] = None,
        seed: tp.Optional[int] = None,
        group_by: tp.GroupByLike = None,
        broadcast_kwargs: tp.KwargsLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        bm_close: tp.Optional[tp.ArrayLike] = None,
        records: tp.Optional[tp.RecordsLike] = None,
        return_preparer: bool = False,
        return_prep_result: bool = False,
        return_sim_out: bool = False,
        **kwargs,
    ) -> PortfolioResultT:
        """Simulate portfolio from orders - size, price, fees, and other information.

        See `vectorbtpro.portfolio.nb.from_orders.from_orders_nb`.

        Prepared by `vectorbtpro.portfolio.preparing.FOPreparer`.

        Args:
            close (array_like, OHLCDataMixin, FOPreparer, or PFPrepResult): Latest asset price at each time step.
                Will broadcast.

                Used for calculating unrealized PnL and portfolio value.

                If an instance of `vectorbtpro.data.base.OHLCDataMixin`, will extract the open, high,
                low, and close price. If an instance of `vectorbtpro.portfolio.preparing.FOPreparer`,
                will use it as a preparer. If an instance of `vectorbtpro.portfolio.preparing.PFPrepResult`,
                will use it as a preparer result.
            size (float or array_like): Size to order.
                See `vectorbtpro.portfolio.enums.Order.size`. Will broadcast.
            size_type (SizeType or array_like): See `vectorbtpro.portfolio.enums.SizeType` and
                `vectorbtpro.portfolio.enums.Order.size_type`. Will broadcast.
            direction (Direction or array_like): See `vectorbtpro.portfolio.enums.Direction` and
                `vectorbtpro.portfolio.enums.Order.direction`. Will broadcast.
            price (array_like of float): Order price.
                Will broadcast.

                See `vectorbtpro.portfolio.enums.Order.price`. Can be also provided as
                `vectorbtpro.portfolio.enums.PriceType`.

                Options `PriceType.NextOpen` and `PriceType.NextClose` are only applicable per column,
                that is, they cannot be used inside full arrays. In addition, they require
                the argument `from_ago` to be None.
            fees (float or array_like): Fees in percentage of the order value.
                See `vectorbtpro.portfolio.enums.Order.fees`. Will broadcast.
            fixed_fees (float or array_like): Fixed amount of fees to pay per order.
                See `vectorbtpro.portfolio.enums.Order.fixed_fees`. Will broadcast.
            slippage (float or array_like): Slippage in percentage of price.
                See `vectorbtpro.portfolio.enums.Order.slippage`. Will broadcast.
            min_size (float or array_like): Minimum size for an order to be accepted.
                See `vectorbtpro.portfolio.enums.Order.min_size`. Will broadcast.
            max_size (float or array_like): Maximum size for an order.
                See `vectorbtpro.portfolio.enums.Order.max_size`. Will broadcast.

                Will be partially filled if exceeded.
            size_granularity (float or array_like): Granularity of the size.
                See `vectorbtpro.portfolio.enums.Order.size_granularity`. Will broadcast.
            leverage (float or array_like): Leverage.
                See `vectorbtpro.portfolio.enums.Order.leverage`. Will broadcast.
            leverage_mode (LeverageMode or array_like): Leverage mode.
                See `vectorbtpro.portfolio.enums.Order.leverage_mode`. Will broadcast.
            reject_prob (float or array_like): Order rejection probability.
                See `vectorbtpro.portfolio.enums.Order.reject_prob`. Will broadcast.
            price_area_vio_mode (PriceAreaVioMode or array_like): See `vectorbtpro.portfolio.enums.PriceAreaVioMode`.
                Will broadcast.
            allow_partial (bool or array_like): Whether to allow partial fills.
                See `vectorbtpro.portfolio.enums.Order.allow_partial`. Will broadcast.

                Does not apply when size is `np.inf`.
            raise_reject (bool or array_like): Whether to raise an exception if order gets rejected.
                See `vectorbtpro.portfolio.enums.Order.raise_reject`. Will broadcast.
            log (bool or array_like): Whether to log orders.
                See `vectorbtpro.portfolio.enums.Order.log`. Will broadcast.
            val_price (array_like of float): Asset valuation price.
                Will broadcast.

                Can be also provided as `vectorbtpro.portfolio.enums.ValPriceType`.

                * Any `-np.inf` element is replaced by the latest valuation price
                    (`open` or the latest known valuation price if `ffill_val_price`).
                * Any `np.inf` element is replaced by the current order price.

                Used at the time of decision making to calculate value of each asset in the group,
                for example, to convert target value into target amount.

                !!! note
                    In contrast to `Portfolio.from_order_func`, order price is known beforehand (kind of),
                    thus `val_price` is set to the current order price (using `np.inf`) by default.
                    To valuate using previous close, set it in the settings to `-np.inf`.

                !!! note
                    Make sure to use timestamp for `val_price` that comes before timestamps of
                    all orders in the group with cash sharing (previous `close` for example),
                    otherwise you're cheating yourself.
            open (array_like of float): First asset price at each time step.
                Defaults to `np.nan`. Will broadcast.

                Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
            high (array_like of float): Highest asset price at each time step.
                Defaults to `np.nan`. Will broadcast.

                Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
            low (array_like of float): Lowest asset price at each time step.
                Defaults to `np.nan`. Will broadcast.

                Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
            init_cash (InitCashMode, float or array_like): Initial capital.
                By default, will broadcast to the final number of columns.
                But if cash sharing is enabled, will broadcast to the number of groups.
                See `vectorbtpro.portfolio.enums.InitCashMode` to find optimal initial cash.

                !!! note
                    Mode `InitCashMode.AutoAlign` is applied after the portfolio is initialized
                    to set the same initial cash for all columns/groups. Changing grouping
                    will change the initial cash, so be aware when indexing.
            init_position (float or array_like): Initial position.
                By default, will broadcast to the final number of columns.
            init_price (float or array_like): Initial position price.
                By default, will broadcast to the final number of columns.
            cash_deposits (float or array_like): Cash to be deposited/withdrawn at each timestamp.
                Will broadcast to the final shape. Must have the same number of columns as `init_cash`.

                Applied at the beginning of each timestamp.
            cash_earnings (float or array_like): Earnings in cash to be added at each timestamp.
                Will broadcast to the final shape.

                Applied at the end of each timestamp.
            cash_dividends (float or array_like): Dividends in cash to be added at each timestamp.
                Will broadcast to the final shape.

                Gets multiplied by the position and saved into `cash_earnings`.

                Applied at the end of each timestamp.
            cash_sharing (bool): Whether to share cash within the same group.
                If `group_by` is None and `cash_sharing` is True, `group_by` becomes True to form a single
                group with cash sharing.

                !!! warning
                    Introduces cross-asset dependencies.

                    This method presumes that in a group of assets that share the same capital all
                    orders will be executed within the same tick and retain their price regardless
                    of their position in the queue, even though they depend upon each other and thus
                    cannot be executed in parallel.
            from_ago (int or array_like): Take order information from a number of bars ago.
                Will broadcast.

                Negative numbers will be cast to positive to avoid the look-ahead bias. Defaults to 0.
                Remember to account for it if you're using a custom signal function!
            sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
                Can be "auto", which will be substituted by the index of the first non-NA size value.
            sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
                Can be "auto", which will be substituted by the index of the first non-NA size value.
            call_seq (CallSeqType or array_like): Default sequence of calls per row and group.
                Each value in this sequence must indicate the position of column in the group to
                call next. Processing of `call_seq` goes always from left to right.
                For example, `[2, 0, 1]` would first call column 'c', then 'a', and finally 'b'.

                Supported are multiple options:

                * Set to None to generate the default call sequence on the fly. Will create a
                    full array only if `attach_call_seq` is True.
                * Use `vectorbtpro.portfolio.enums.CallSeqType` to create a full array of a specific type.
                * Set to array to specify a custom call sequence.

                If `CallSeqType.Auto` selected, rearranges calls dynamically based on order value.
                Calculates value of all orders per row and group, and sorts them by this value.
                Sell orders will be executed first to release funds for buy orders.

                !!! warning
                    `CallSeqType.Auto` should be used with caution:

                    * It not only presumes that order prices are known beforehand, but also that
                        orders can be executed in arbitrary order and still retain their price.
                        In reality, this is hardly the case: after processing one asset, some time
                        has passed and the price for other assets might have already changed.
                    * Even if you're able to specify a slippage large enough to compensate for
                        this behavior, slippage itself should depend upon execution order.
                        This method doesn't let you do that.
                    * Orders in the same queue are executed regardless of whether previous orders
                        have been filled, which can leave them without required funds.

                    For more control, use `Portfolio.from_order_func`.
            attach_call_seq (bool): Whether to attach `call_seq` to the instance.
                Makes sense if you want to analyze the simulation order. Otherwise, just takes memory.
            ffill_val_price (bool): Whether to track valuation price only if it's known.
                Otherwise, unknown `close` will lead to NaN in valuation price at the next timestamp.
            update_value (bool): Whether to update group value after each filled order.
            save_state (bool): Whether to save the state.
                The arrays will be available as `cash`, `position`, `debt`, `locked_cash`,
                and `free_cash` in in-outputs.
            save_value (bool): Whether to save the value.
                The array will be available as `value` in in-outputs.
            save_returns (bool): Whether to save the returns.
                The array will be available as `returns` in in-outputs.
            skip_empty (bool): Whether to skip rows with no order.
            max_order_records (int): The max number of order records expected to be filled at each column.
                Defaults to the maximum number of non-NaN values across all columns of the size array.

                Set to a lower number if you run out of memory, and to 0 to not fill.
            max_log_records (int): The max number of log records expected to be filled at each column.
                Defaults to the maximum number of True values across all columns of the log array.

                Set to a lower number if you run out of memory, and to 0 to not fill.
            seed (int): Seed to be set for both `call_seq` and at the beginning of the simulation.
            group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
            broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
            jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
            chunked (any): See `vectorbtpro.utils.chunking.resolve_chunked_option`.
            bm_close (array_like): Latest benchmark price at each time step.
                Will broadcast.

                If not provided, will use `close`. If False, will not use any benchmark.
            records (array_like): Records to construct arrays from.
                See `vectorbtpro.base.indexing.IdxRecords`.
            return_preparer (bool): Whether to return the preparer of the type
                `vectorbtpro.portfolio.preparing.FOPreparer`.

                !!! note
                    Seed won't be set in this case, you need to explicitly call `preparer.set_seed()`.
            return_prep_result (bool): Whether to return the preparer result of the type
                `vectorbtpro.portfolio.preparing.PFPrepResult`.
            return_sim_out (bool): Whether to return the simulation output of the type
                `vectorbtpro.portfolio.enums.SimulationOutput`.
            **kwargs: Keyword arguments passed to the `Portfolio` constructor.

        Returns:
            The first applicable of the following: the preparer if `return_preparer` is True,
            the preparer result if `return_prep_result` is True, the simulation output if
            `return_sim_out` is True, otherwise a new instance of `cls` built from
            the simulation output and the portfolio arguments of the preparer result.

        All broadcastable arguments will broadcast using `vectorbtpro.base.reshaping.broadcast`
        but keep original shape to utilize flexible indexing and to save memory.

        For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill
        NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN
        for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO`
        with `fill_value` to override.

        !!! note
            When `call_seq` is not `CallSeqType.Auto`, at each timestamp, processing of the assets in
            a group goes strictly in order defined in `call_seq`. This order can't be changed dynamically.

            This has one big implication for this particular method: the last asset in the call stack
            cannot be processed until other assets are processed. This is the reason why rebalancing
            cannot work properly in this setting: one has to specify percentages for all assets beforehand
            and then tweak the processing order to sell to-be-sold assets first in order to release funds
            for to-be-bought assets. This can be automatically done by using `CallSeqType.Auto`.

        !!! hint
            All broadcastable arguments can be set per frame, series, row, column, or element.

        Usage:
            * Buy 10 units each tick:

            ```pycon
            >>> close = pd.Series([1, 2, 3, 4, 5])
            >>> pf = vbt.Portfolio.from_orders(close, 10)

            >>> pf.assets
            0    10.0
            1    20.0
            2    30.0
            3    40.0
            4    40.0
            dtype: float64
            >>> pf.cash
            0    90.0
            1    70.0
            2    40.0
            3     0.0
            4     0.0
            dtype: float64
            ```

            * Reverse each position by first closing it:

            ```pycon
            >>> size = [1, 0, -1, 0, 1]
            >>> pf = vbt.Portfolio.from_orders(close, size, size_type='targetpercent')

            >>> pf.assets
            0    100.000000
            1      0.000000
            2    -66.666667
            3      0.000000
            4     26.666667
            dtype: float64
            >>> pf.cash
            0      0.000000
            1    200.000000
            2    400.000000
            3    133.333333
            4      0.000000
            dtype: float64
            ```

            * Regularly deposit cash at open and invest it within the same bar at close:

            ```pycon
            >>> close = pd.Series([1, 2, 3, 4, 5])
            >>> cash_deposits = pd.Series([10., 0., 10., 0., 10.])
            >>> pf = vbt.Portfolio.from_orders(
            ...     close,
            ...     size=cash_deposits,  # invest the amount deposited
            ...     size_type='value',
            ...     cash_deposits=cash_deposits
            ... )

            >>> pf.cash
            0    100.0
            1    100.0
            2    100.0
            3    100.0
            4    100.0
            dtype: float64

            >>> pf.asset_flow
            0    10.000000
            1     0.000000
            2     3.333333
            3     0.000000
            4     2.000000
            dtype: float64
            ```

            * Equal-weighted portfolio as in `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb` example
            (it's more compact but has less control over execution):

            ```pycon
            >>> np.random.seed(42)
            >>> close = pd.DataFrame(np.random.uniform(1, 10, size=(5, 3)))
            >>> size = pd.Series(np.full(5, 1/3))  # each column 33.3%
            >>> size[1::2] = np.nan  # skip every second tick

            >>> pf = vbt.Portfolio.from_orders(
            ...     close,  # acts both as reference and order price here
            ...     size,
            ...     size_type='targetpercent',
            ...     direction='longonly',
            ...     call_seq='auto',  # first sell then buy
            ...     group_by=True,  # one group
            ...     cash_sharing=True,  # assets share the same cash
            ...     fees=0.001, fixed_fees=1., slippage=0.001  # costs
            ... )

            >>> pf.get_asset_value(group_by=False).vbt.plot().show()
            ```

            ![](/assets/images/api/from_orders.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/from_orders.dark.svg#only-dark){: .iimg loading=lazy }

            * Test 10 random weight combinations:

            ```pycon
            >>> np.random.seed(42)
            >>> close = pd.DataFrame(
            ...     np.random.uniform(1, 10, size=(5, 3)),
            ...     columns=pd.Index(['a', 'b', 'c'], name='asset'))

            >>> # Generate random weight combinations
            >>> rand_weights = []
            >>> for i in range(10):
            ...     rand_weights.append(np.random.dirichlet(np.ones(close.shape[1]), size=1)[0])
            >>> rand_weights
            [array([0.15474873, 0.27706078, 0.5681905 ]),
             array([0.30468598, 0.18545189, 0.50986213]),
             array([0.15780486, 0.36292607, 0.47926907]),
             array([0.25697713, 0.64902589, 0.09399698]),
             array([0.43310548, 0.53836359, 0.02853093]),
             array([0.78628605, 0.15716865, 0.0565453 ]),
             array([0.37186671, 0.42150531, 0.20662798]),
             array([0.22441579, 0.06348919, 0.71209502]),
             array([0.41619664, 0.09338007, 0.49042329]),
             array([0.01279537, 0.87770864, 0.10949599])]

            >>> # Bring close and rand_weights to the same shape
            >>> rand_weights = np.concatenate(rand_weights)
            >>> close = close.vbt.tile(10, keys=pd.Index(np.arange(10), name='weights_vector'))
            >>> size = vbt.broadcast_to(rand_weights, close).copy()
            >>> size[1::2] = np.nan
            >>> size
            weights_vector                            0  ...                               9
            asset                  a         b        c  ...         a         b         c
            0               0.154749  0.277061  0.56819  ...  0.012795  0.877709  0.109496
            1                    NaN       NaN      NaN  ...       NaN       NaN       NaN
            2               0.154749  0.277061  0.56819  ...  0.012795  0.877709  0.109496
            3                    NaN       NaN      NaN  ...       NaN       NaN       NaN
            4               0.154749  0.277061  0.56819  ...  0.012795  0.877709  0.109496

            [5 rows x 30 columns]

            >>> pf = vbt.Portfolio.from_orders(
            ...     close,
            ...     size,
            ...     size_type='targetpercent',
            ...     direction='longonly',
            ...     call_seq='auto',
            ...     group_by='weights_vector',  # group by column level
            ...     cash_sharing=True,
            ...     fees=0.001, fixed_fees=1., slippage=0.001
            ... )

            >>> pf.total_return
            weights_vector
            0   -0.294372
            1    0.139207
            2   -0.281739
            3    0.041242
            4    0.467566
            5    0.829925
            6    0.320672
            7   -0.087452
            8    0.376681
            9   -0.702773
            Name: total_return, dtype: float64
            ```
        """
        # NOTE(review): the docstring describes "auto" for `sim_end` with the same sentence as for
        # `sim_start` ("first non-NA size value"); this looks like a copy-paste slip - confirm
        # against `FOPreparer` before rewording.
        if isinstance(close, FOPreparer):
            # A ready-made preparer was passed in place of `close`: use it as is
            preparer = close
            prep_result = None
        elif isinstance(close, PFPrepResult):
            # A ready-made preparer result was passed: there is no preparer to build or return
            # NOTE(review): with this input, `return_preparer=True` returns None - confirm that's intended
            preparer = None
            prep_result = close
        else:
            # `locals()` must stay the first statement of this branch: at this point it holds
            # only the call arguments, which are forwarded to the preparer by name
            local_kwargs = locals()
            # Flatten `**kwargs` into the top level, then drop everything that isn't a preparer argument
            local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
            del local_kwargs["kwargs"]
            del local_kwargs["cls"]
            del local_kwargs["return_preparer"]
            del local_kwargs["return_prep_result"]
            del local_kwargs["return_sim_out"]
            parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
            if parsed_data is not None:
                # `close` was recognized as data: pass it as `data` and clear `close`
                local_kwargs["data"] = parsed_data
                local_kwargs["close"] = None
            preparer = FOPreparer(**local_kwargs)
            if not return_preparer:
                # The seed is set only if the preparer isn't handed back to the caller
                preparer.set_seed()
            prep_result = None
        if return_preparer:
            return preparer
        if prep_result is None:
            prep_result = preparer.result
        if return_prep_result:
            return prep_result
        # Run the target function with the prepared arguments
        sim_out = prep_result.target_func(**prep_result.target_args)
        if return_sim_out:
            return sim_out
        return cls(order_records=sim_out, **prep_result.pf_args)

    @classmethod
    def from_signals(
        cls: tp.Type[PortfolioT],
        close: tp.Union[tp.ArrayLike, OHLCDataMixin, FSPreparer, PFPrepResult],
        entries: tp.Optional[tp.ArrayLike] = None,
        exits: tp.Optional[tp.ArrayLike] = None,
        *,
        direction: tp.Optional[tp.ArrayLike] = None,
        long_entries: tp.Optional[tp.ArrayLike] = None,
        long_exits: tp.Optional[tp.ArrayLike] = None,
        short_entries: tp.Optional[tp.ArrayLike] = None,
        short_exits: tp.Optional[tp.ArrayLike] = None,
        adjust_func_nb: tp.Union[None, tp.PathLike, nb.AdjustFuncT] = None,
        adjust_args: tp.Args = (),
        signal_func_nb: tp.Union[None, tp.PathLike, nb.SignalFuncT] = None,
        signal_args: tp.ArgsLike = (),
        post_segment_func_nb: tp.Union[None, tp.PathLike, nb.PostSegmentFuncT] = None,
        post_segment_args:
tp.ArgsLike = (), order_mode: bool = False, size: tp.Optional[tp.ArrayLike] = None, size_type: tp.Optional[tp.ArrayLike] = None, price: tp.Optional[tp.ArrayLike] = None, fees: tp.Optional[tp.ArrayLike] = None, fixed_fees: tp.Optional[tp.ArrayLike] = None, slippage: tp.Optional[tp.ArrayLike] = None, min_size: tp.Optional[tp.ArrayLike] = None, max_size: tp.Optional[tp.ArrayLike] = None, size_granularity: tp.Optional[tp.ArrayLike] = None, leverage: tp.Optional[tp.ArrayLike] = None, leverage_mode: tp.Optional[tp.ArrayLike] = None, reject_prob: tp.Optional[tp.ArrayLike] = None, price_area_vio_mode: tp.Optional[tp.ArrayLike] = None, allow_partial: tp.Optional[tp.ArrayLike] = None, raise_reject: tp.Optional[tp.ArrayLike] = None, log: tp.Optional[tp.ArrayLike] = None, val_price: tp.Optional[tp.ArrayLike] = None, accumulate: tp.Optional[tp.ArrayLike] = None, upon_long_conflict: tp.Optional[tp.ArrayLike] = None, upon_short_conflict: tp.Optional[tp.ArrayLike] = None, upon_dir_conflict: tp.Optional[tp.ArrayLike] = None, upon_opposite_entry: tp.Optional[tp.ArrayLike] = None, order_type: tp.Optional[tp.ArrayLike] = None, limit_delta: tp.Optional[tp.ArrayLike] = None, limit_tif: tp.Optional[tp.ArrayLike] = None, limit_expiry: tp.Optional[tp.ArrayLike] = None, limit_reverse: tp.Optional[tp.ArrayLike] = None, limit_order_price: tp.Optional[tp.ArrayLike] = None, upon_adj_limit_conflict: tp.Optional[tp.ArrayLike] = None, upon_opp_limit_conflict: tp.Optional[tp.ArrayLike] = None, use_stops: tp.Optional[bool] = None, stop_ladder: tp.Optional[bool] = None, sl_stop: tp.Optional[tp.ArrayLike] = None, tsl_stop: tp.Optional[tp.ArrayLike] = None, tsl_th: tp.Optional[tp.ArrayLike] = None, tp_stop: tp.Optional[tp.ArrayLike] = None, td_stop: tp.Optional[tp.ArrayLike] = None, dt_stop: tp.Optional[tp.ArrayLike] = None, stop_entry_price: tp.Optional[tp.ArrayLike] = None, stop_exit_price: tp.Optional[tp.ArrayLike] = None, stop_exit_type: tp.Optional[tp.ArrayLike] = None, stop_order_type: 
tp.Optional[tp.ArrayLike] = None, stop_limit_delta: tp.Optional[tp.ArrayLike] = None, upon_stop_update: tp.Optional[tp.ArrayLike] = None, upon_adj_stop_conflict: tp.Optional[tp.ArrayLike] = None, upon_opp_stop_conflict: tp.Optional[tp.ArrayLike] = None, delta_format: tp.Optional[tp.ArrayLike] = None, time_delta_format: tp.Optional[tp.ArrayLike] = None, open: tp.Optional[tp.ArrayLike] = None, high: tp.Optional[tp.ArrayLike] = None, low: tp.Optional[tp.ArrayLike] = None, init_cash: tp.Optional[tp.ArrayLike] = None, init_position: tp.Optional[tp.ArrayLike] = None, init_price: tp.Optional[tp.ArrayLike] = None, cash_deposits: tp.Optional[tp.ArrayLike] = None, cash_earnings: tp.Optional[tp.ArrayLike] = None, cash_dividends: tp.Optional[tp.ArrayLike] = None, cash_sharing: tp.Optional[bool] = None, from_ago: tp.Optional[tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, call_seq: tp.Optional[tp.ArrayLike] = None, attach_call_seq: tp.Optional[bool] = None, ffill_val_price: tp.Optional[bool] = None, update_value: tp.Optional[bool] = None, fill_pos_info: tp.Optional[bool] = None, save_state: tp.Optional[bool] = None, save_value: tp.Optional[bool] = None, save_returns: tp.Optional[bool] = None, skip_empty: tp.Optional[bool] = None, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = None, in_outputs: tp.Optional[tp.MappingLike] = None, seed: tp.Optional[int] = None, group_by: tp.GroupByLike = None, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, staticized: tp.StaticizedOption = None, bm_close: tp.Optional[tp.ArrayLike] = None, records: tp.Optional[tp.RecordsLike] = None, return_preparer: bool = False, return_prep_result: bool = False, return_sim_out: bool = False, **kwargs, ) -> PortfolioResultT: """Simulate portfolio from entry and exit signals. 
Supports the following modes: 1. `entries` and `exits`: Uses `vectorbtpro.portfolio.nb.from_signals.dir_signal_func_nb` as `signal_func_nb` if an adjustment function is provided (not cacheable), otherwise translates signals using `vectorbtpro.portfolio.nb.from_signals.dir_to_ls_signals_nb` then simulates statically (cacheable) 2. `entries` (acting as long), `exits` (acting as long), `short_entries`, and `short_exits`: Uses `vectorbtpro.portfolio.nb.from_signals.ls_signal_func_nb` as `signal_func_nb` if an adjustment function is provided (not cacheable), otherwise simulates statically (cacheable) 3. `order_mode=True` without signals: Uses `vectorbtpro.portfolio.nb.from_signals.order_signal_func_nb` as `signal_func_nb` (not cacheable) 4. `signal_func_nb` and `signal_args`: Custom signal function (not cacheable) Prepared by `vectorbtpro.portfolio.preparing.FSPreparer`. Args: close (array_like, OHLCDataMixin, FSPreparer, or PFPrepResult): See `Portfolio.from_orders`. entries (array_like of bool): Boolean array of entry signals. Defaults to True if all other signal arrays are not set, otherwise False. Will broadcast. * If `short_entries` and `short_exits` are not set: Acts as a long signal if `direction` is 'all' or 'longonly', otherwise short. * If `short_entries` or `short_exits` are set: Acts as `long_entries`. exits (array_like of bool): Boolean array of exit signals. Defaults to False. Will broadcast. * If `short_entries` and `short_exits` are not set: Acts as a short signal if `direction` is 'all' or 'longonly', otherwise long. * If `short_entries` or `short_exits` are set: Acts as `long_exits`. direction (Direction or array_like): See `Portfolio.from_orders`. Takes only effect if `short_entries` and `short_exits` are not set. long_entries (array_like of bool): Boolean array of long entry signals. Defaults to False. Will broadcast. long_exits (array_like of bool): Boolean array of long exit signals. Defaults to False. Will broadcast. 
short_entries (array_like of bool): Boolean array of short entry signals. Defaults to False. Will broadcast. short_exits (array_like of bool): Boolean array of short exit signals. Defaults to False. Will broadcast. adjust_func_nb (path_like or callable): User-defined function to adjust the current simulation state. Defaults to `vectorbtpro.portfolio.nb.from_signals.no_adjust_func_nb`. Passed as argument to `vectorbtpro.portfolio.nb.from_signals.dir_signal_func_nb`, `vectorbtpro.portfolio.nb.from_signals.ls_signal_func_nb`, and `vectorbtpro.portfolio.nb.from_signals.order_signal_func_nb`. Has no effect when using other signal functions. Can be a path to a module when using staticizing. adjust_args (tuple): Packed arguments passed to `adjust_func_nb`. signal_func_nb (path_like or callable): Function called to generate signals. See `vectorbtpro.portfolio.nb.from_signals.from_signal_func_nb`. Can be a path to a module when using staticizing. signal_args (tuple): Packed arguments passed to `signal_func_nb`. post_segment_func_nb (path_like or callable): Post-segment function. See `vectorbtpro.portfolio.nb.from_signals.from_signal_func_nb`. Can be a path to a module when using staticizing. post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`. order_mode (bool): Whether to simulate as orders without signals. size (float or array_like): See `Portfolio.from_orders`. !!! note Negative size is not allowed. You must express direction using signals. size_type (SizeType or array_like): See `Portfolio.from_orders`. Only `SizeType.Amount`, `SizeType.Value`, `SizeType.Percent(100)`, and `SizeType.ValuePercent(100)` are supported. Other modes such as target percentage are not compatible with signals since their logic may contradict the direction of the signal. !!! note `SizeType.Percent(100)` does not support position reversal. Switch to a single direction or use `OppositeEntryMode.Close` to close the position first. See warning in `Portfolio.from_orders`. 
price (array_like of float): See `Portfolio.from_orders`. fees (float or array_like): See `Portfolio.from_orders`. fixed_fees (float or array_like): See `Portfolio.from_orders`. slippage (float or array_like): See `Portfolio.from_orders`. min_size (float or array_like): See `Portfolio.from_orders`. max_size (float or array_like): See `Portfolio.from_orders`. Will be partially filled if exceeded. You might not be able to properly close the position if accumulation is enabled and `max_size` is too low. size_granularity (float or array_like): See `Portfolio.from_orders`. leverage (float or array_like): See `Portfolio.from_orders`. leverage_mode (LeverageMode or array_like): See `Portfolio.from_orders`. reject_prob (float or array_like): See `Portfolio.from_orders`. price_area_vio_mode (PriceAreaVioMode or array_like): See `Portfolio.from_orders`. allow_partial (bool or array_like): See `Portfolio.from_orders`. raise_reject (bool or array_like): See `Portfolio.from_orders`. log (bool or array_like): See `Portfolio.from_orders`. val_price (array_like of float): See `Portfolio.from_orders`. accumulate (bool, AccumulationMode or array_like): See `vectorbtpro.portfolio.enums.AccumulationMode`. If True, becomes 'both'. If False, becomes 'disabled'. Will broadcast. When enabled, `Portfolio.from_signals` behaves similarly to `Portfolio.from_orders`. upon_long_conflict (ConflictMode or array_like): Conflict mode for long signals. See `vectorbtpro.portfolio.enums.ConflictMode`. Will broadcast. upon_short_conflict (ConflictMode or array_like): Conflict mode for short signals. See `vectorbtpro.portfolio.enums.ConflictMode`. Will broadcast. upon_dir_conflict (DirectionConflictMode or array_like): See `vectorbtpro.portfolio.enums.DirectionConflictMode`. Will broadcast. upon_opposite_entry (OppositeEntryMode or array_like): See `vectorbtpro.portfolio.enums.OppositeEntryMode`. Will broadcast. order_type (OrderType or array_like): See `vectorbtpro.portfolio.enums.OrderType`. 
Only one active limit order is allowed at a time. limit_delta (float or array_like): Delta from `price` to build the limit price. Will broadcast. If NaN, `price` becomes the limit price. Otherwise, applied on top of `price` depending on the current direction: if the direction-aware size is positive (= buying), a positive delta will decrease the limit price; if the direction-aware size is negative (= selling), a positive delta will increase the limit price. Delta can be negative. Set an element to `np.nan` to disable. Use `delta_format` to specify the format. limit_tif (frequency_like or array_like): Time in force for limit signals. Will broadcast. Any frequency-like object is converted using `vectorbtpro.utils.datetime_.to_timedelta64`. Any array must either contain timedeltas or integers, and will be cast into integer format after broadcasting. If the object provided is of data type `object`, will be converted to timedelta automatically. Measured in the distance after the open time of the signal bar. If the expiration time happens in the middle of the current bar, we pessimistically assume that the order has been expired. The check is performed at the beginning of the bar, and the first check is performed at the next bar after the signal. For example, if the format is `TimeDeltaFormat.Rows`, 0 or 1 means the order must execute at the same bar or not at all; 2 means the order must execute at the same or next bar or not at all. Set an element to `-1` to disable. Use `time_delta_format` to specify the format. limit_expiry (frequency_like, datetime_like, or array_like): Expiration time. Will broadcast. Any frequency-like object is used to build a period index, such that each timestamp in the original index is pointing to the timestamp where the period ends. For example, providing "d" will make any limit order expire on the next day. Any array must either contain timestamps or integers (not timedeltas!), and will be cast into integer format after broadcasting. 
If the object provided is of data type `object`, will be converted to datetime and its timezone will be removed automatically (as done on the index). Behaves in a similar way as `limit_tif`. Set an element to `-1` or `pd.Timestamp.max` to disable. Use `time_delta_format` to specify the format. limit_reverse (bool or array_like): Whether to reverse the price hit detection. Will broadcast. If True, a buy/sell limit price will be checked against high/low (not low/high). Also, the limit delta will be applied above/below (not below/above) the initial price. limit_order_price (LimitOrderPrice or array_like): See `vectorbtpro.portfolio.enums.LimitOrderPrice`. Will broadcast. If provided on per-element basis, gets applied upon order creation. If a positive value is provided, used directly as a price, otherwise used as an enumerated value. upon_adj_limit_conflict (PendingConflictMode or array_like): Conflict mode for limit and user-defined signals of adjacent sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast. upon_opp_limit_conflict (PendingConflictMode or array_like): Conflict mode for limit and user-defined signals of opposite sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast. use_stops (bool): Whether to use stops. Defaults to None, which becomes True if any of the stops are not NaN or the adjustment function is not the default one. Disable this to make simulation a bit faster for simple use cases. stop_ladder (bool or StopLadderMode): Whether and which kind of stop laddering to use. See `vectorbtpro.portfolio.enums.StopLadderMode`. If so, rows in the supplied arrays will become ladder steps. Make sure that they are increasing. If one column should have less steps, pad it with NaN for price-based stops and -1 for time-based stops. Rows in each array can be of an arbitrary length but columns must broadcast against the number of columns in the data. Applied on all stop types. sl_stop (array_like of float): Stop loss. 
Will broadcast. Set an element to `np.nan` to disable. Use `delta_format` to specify the format. tsl_stop (array_like of float): Trailing stop loss. Will broadcast. Set an element to `np.nan` to disable. Use `delta_format` to specify the format. tsl_th (array_like of float): Take profit threshold for the trailing stop loss. Will broadcast. Set an element to `np.nan` to disable. Use `delta_format` to specify the format. tp_stop (array_like of float): Take profit. Will broadcast. Set an element to `np.nan` to disable. Use `delta_format` to specify the format. td_stop (frequency_like or array_like): Timedelta-stop. Will broadcast. Set an element to `-1` to disable. Use `time_delta_format` to specify the format. dt_stop (frequency_like, datetime_like, or array_like): Datetime-stop. Will broadcast. Set an element to `-1` to disable. Use `time_delta_format` to specify the format. stop_entry_price (StopEntryPrice or array_like): See `vectorbtpro.portfolio.enums.StopEntryPrice`. Will broadcast. If provided on per-element basis, gets applied upon entry. If a positive value is provided, used directly as a price, otherwise used as an enumerated value. stop_exit_price (StopExitPrice or array_like): See `vectorbtpro.portfolio.enums.StopExitPrice`. Will broadcast. If provided on per-element basis, gets applied upon entry. If a positive value is provided, used directly as a price, otherwise used as an enumerated value. stop_exit_type (StopExitType or array_like): See `vectorbtpro.portfolio.enums.StopExitType`. Will broadcast. If provided on per-element basis, gets applied upon entry. stop_order_type (OrderType or array_like): Similar to `order_type` but for stop orders. Will broadcast. If provided on per-element basis, gets applied upon entry. stop_limit_delta (float or array_like): Similar to `limit_delta` but for stop orders. Will broadcast. upon_stop_update (StopUpdateMode or array_like): See `vectorbtpro.portfolio.enums.StopUpdateMode`. Will broadcast. 
Only has effect if accumulation is enabled. If provided on per-element basis, gets applied upon repeated entry. upon_adj_stop_conflict (PendingConflictMode or array_like): Conflict mode for stop and user-defined signals of adjacent sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast. upon_opp_stop_conflict (PendingConflictMode or array_like): Conflict mode for stop and user-defined signals of opposite sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast. delta_format (DeltaFormat or array_like): See `vectorbtpro.portfolio.enums.DeltaFormat`. Will broadcast. time_delta_format (TimeDeltaFormat or array_like): See `vectorbtpro.portfolio.enums.TimeDeltaFormat`. Will broadcast. open (array_like of float): See `Portfolio.from_orders`. For stop signals, `np.nan` gets replaced by `close`. high (array_like of float): See `Portfolio.from_orders`. For stop signals, `np.nan` replaced by the maximum out of `open` and `close`. low (array_like of float): See `Portfolio.from_orders`. For stop signals, `np.nan` replaced by the minimum out of `open` and `close`. init_cash (InitCashMode, float or array_like): See `Portfolio.from_orders`. init_position (float or array_like): See `Portfolio.from_orders`. init_price (float or array_like): See `Portfolio.from_orders`. cash_deposits (float or array_like): See `Portfolio.from_orders`. cash_earnings (float or array_like): See `Portfolio.from_orders`. cash_dividends (float or array_like): See `Portfolio.from_orders`. cash_sharing (bool): See `Portfolio.from_orders`. from_ago (int or array_like): See `Portfolio.from_orders`. Take effect only for user-defined signals, not for stop signals. sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive). Can be "auto", which will be substituted by the index of the first signal across long and short entries and long and short exits. sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive). 
Can be "auto", which will be substituted by the index of the last signal across long and short entries and long and short exits. call_seq (CallSeqType or array_like): See `Portfolio.from_orders`. attach_call_seq (bool): See `Portfolio.from_orders`. ffill_val_price (bool): See `Portfolio.from_orders`. update_value (bool): See `Portfolio.from_orders`. fill_pos_info (bool): Whether to fill position record. Disable this to make simulation faster for simple use cases. save_state (bool): See `Portfolio.from_orders`. save_value (bool): See `Portfolio.from_orders`. save_returns (bool): See `Portfolio.from_orders`. skip_empty (bool): See `Portfolio.from_orders`. max_order_records (int): See `Portfolio.from_orders`. max_log_records (int): See `Portfolio.from_orders`. in_outputs (mapping_like): Mapping with in-output objects. Only for flexible mode. Will be available via `Portfolio.in_outputs` as a named tuple. To substitute `Portfolio` attributes, provide already broadcasted and grouped objects, for example, by using `broadcast_named_args` and templates. Also see `Portfolio.in_outputs_indexing_func` on how in-output objects are indexed. When chunking, make sure to provide the chunk taking specification and the merging function. See `vectorbtpro.portfolio.chunking.merge_sim_outs`. !!! note When using Numba below 0.54, `in_outputs` cannot be a mapping, but must be a named tuple defined globally so Numba can introspect its attributes for pickling. seed (int): See `Portfolio.from_orders`. group_by (any): See `Portfolio.from_orders`. broadcast_named_args (dict): Dictionary with named arguments to broadcast. You can then pass argument names wrapped with `vectorbtpro.utils.template.Rep` and this method will substitute them by their corresponding broadcasted objects. broadcast_kwargs (dict): See `Portfolio.from_orders`. template_context (mapping): Context used to substitute templates in arguments. jitted (any): See `Portfolio.from_orders`. 
chunked (any): See `Portfolio.from_orders`. staticized (bool, dict, hashable, or callable): Keyword arguments or task id for staticizing. If True or dictionary, will be passed as keyword arguments to `vectorbtpro.utils.cutting.cut_and_save_func` to save a cacheable version of the simulator to a file. If a hashable or callable, will be used as a task id of an already registered jittable and chunkable simulator. Dictionary allows additional options `override` and `reload` to override and reload an already existing module respectively. bm_close (array_like): See `Portfolio.from_orders`. records (array_like): See `Portfolio.from_orders`. return_preparer (bool): See `Portfolio.from_orders`. return_prep_result (bool): See `Portfolio.from_orders`. return_sim_out (bool): See `Portfolio.from_orders`. **kwargs: Keyword arguments passed to the `Portfolio` constructor. All broadcastable arguments will broadcast using `vectorbtpro.base.reshaping.broadcast` but keep original shape to utilize flexible indexing and to save memory. For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO` with `fill_value` to override. Also see notes and hints for `Portfolio.from_orders`. Usage: * By default, if all signal arrays are None, `entries` becomes True, which opens a position at the very first tick and does nothing else: ```pycon >>> close = pd.Series([1, 2, 3, 4, 5]) >>> pf = vbt.Portfolio.from_signals(close, size=1) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 0.0 4 0.0 dtype: float64 ``` * Entry opens long, exit closes long: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1, ... direction='longonly' ... 
) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 -1.0 4 0.0 dtype: float64 >>> # Using direction-aware arrays instead of `direction` >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), # long_entries ... exits=pd.Series([False, False, True, True, True]), # long_exits ... short_entries=False, ... short_exits=False, ... size=1 ... ) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 -1.0 4 0.0 dtype: float64 ``` Notice how both `short_entries` and `short_exits` are provided as constants - as any other broadcastable argument, they are treated as arrays where each element is False. * Entry opens short, exit closes short: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1, ... direction='shortonly' ... ) >>> pf.asset_flow 0 -1.0 1 0.0 2 0.0 3 1.0 4 0.0 dtype: float64 >>> # Using direction-aware arrays instead of `direction` >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=False, # long_entries ... exits=False, # long_exits ... short_entries=pd.Series([True, True, True, False, False]), ... short_exits=pd.Series([False, False, True, True, True]), ... size=1 ... ) >>> pf.asset_flow 0 -1.0 1 0.0 2 0.0 3 1.0 4 0.0 dtype: float64 ``` * Entry opens long and closes short, exit closes long and opens short: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1, ... direction='both' ... ) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 -2.0 4 0.0 dtype: float64 >>> # Using direction-aware arrays instead of `direction` >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), # long_entries ... exits=False, # long_exits ... short_entries=pd.Series([False, False, True, True, True]), ... short_exits=False, ... size=1 ... 
) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 -2.0 4 0.0 dtype: float64 ``` * More complex signal combinations are best expressed using direction-aware arrays. For example, ignore opposite signals as long as the current position is open: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries =pd.Series([True, False, False, False, False]), # long_entries ... exits =pd.Series([False, False, True, False, False]), # long_exits ... short_entries=pd.Series([False, True, False, True, False]), ... short_exits =pd.Series([False, False, False, False, True]), ... size=1, ... upon_opposite_entry='ignore' ... ) >>> pf.asset_flow 0 1.0 1 0.0 2 -1.0 3 -1.0 4 1.0 dtype: float64 ``` * First opposite signal closes the position, second one opens a new position: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1, ... direction='both', ... upon_opposite_entry='close' ... ) >>> pf.asset_flow 0 1.0 1 0.0 2 0.0 3 -1.0 4 -1.0 dtype: float64 ``` * If both long entry and exit signals are True (a signal conflict), choose exit: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1., ... direction='longonly', ... upon_long_conflict='exit') >>> pf.asset_flow 0 1.0 1 0.0 2 -1.0 3 0.0 4 0.0 dtype: float64 ``` * If both long entry and short entry signal are True (a direction conflict), choose short: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1., ... direction='both', ... upon_dir_conflict='short') >>> pf.asset_flow 0 1.0 1 0.0 2 -2.0 3 0.0 4 0.0 dtype: float64 ``` !!! 
note Remember that when direction is set to 'both', entries become `long_entries` and exits become `short_entries`, so this becomes a conflict of directions rather than signals. * If there are both signal and direction conflicts: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=True, # long_entries ... exits=True, # long_exits ... short_entries=True, ... short_exits=True, ... size=1, ... upon_long_conflict='entry', ... upon_short_conflict='entry', ... upon_dir_conflict='short' ... ) >>> pf.asset_flow 0 -1.0 1 0.0 2 0.0 3 0.0 4 0.0 dtype: float64 ``` * Turn on accumulation of signals. Entry means long order, exit means short order (acts similar to `from_orders`): ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1., ... direction='both', ... accumulate=True) >>> pf.asset_flow 0 1.0 1 1.0 2 0.0 3 -1.0 4 -1.0 dtype: float64 ``` * Allow increasing a position (of any direction), deny decreasing a position: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... size=1., ... direction='both', ... accumulate='addonly') >>> pf.asset_flow 0 1.0 << open a long position 1 1.0 << add to the position 2 0.0 3 -3.0 << close and open a short position 4 -1.0 << add to the position dtype: float64 ``` * Test multiple parameters via regular broadcasting: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... direction=[list(Direction)], ... 
broadcast_kwargs=dict(columns_from=pd.Index(vbt.pf_enums.Direction._fields, name='direction'))) >>> pf.asset_flow direction LongOnly ShortOnly Both 0 100.0 -100.0 100.0 1 0.0 0.0 0.0 2 0.0 0.0 0.0 3 -100.0 50.0 -200.0 4 0.0 0.0 0.0 ``` * Test multiple parameters via `vectorbtpro.base.reshaping.BCO`: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, ... entries=pd.Series([True, True, True, False, False]), ... exits=pd.Series([False, False, True, True, True]), ... direction=vbt.Param(Direction)) >>> pf.asset_flow direction LongOnly ShortOnly Both 0 100.0 -100.0 100.0 1 0.0 0.0 0.0 2 0.0 0.0 0.0 3 -100.0 50.0 -200.0 4 0.0 0.0 0.0 ``` * Set risk/reward ratio by passing trailing stop loss and take profit thresholds: ```pycon >>> close = pd.Series([10, 11, 12, 11, 10, 9]) >>> entries = pd.Series([True, False, False, False, False, False]) >>> exits = pd.Series([False, False, False, False, False, True]) >>> pf = vbt.Portfolio.from_signals( ... close, entries, exits, ... tsl_stop=0.1, tp_stop=0.2) # take profit hit >>> pf.asset_flow 0 10.0 1 0.0 2 -10.0 3 0.0 4 0.0 5 0.0 dtype: float64 >>> pf = vbt.Portfolio.from_signals( ... close, entries, exits, ... tsl_stop=0.1, tp_stop=0.3) # trailing stop loss hit >>> pf.asset_flow 0 10.0 1 0.0 2 0.0 3 0.0 4 -10.0 5 0.0 dtype: float64 >>> pf = vbt.Portfolio.from_signals( ... close, entries, exits, ... tsl_stop=np.inf, tp_stop=np.inf) # nothing hit, exit as usual >>> pf.asset_flow 0 10.0 1 0.0 2 0.0 3 0.0 4 0.0 5 -10.0 dtype: float64 ``` * Test different stop combinations: ```pycon >>> pf = vbt.Portfolio.from_signals( ... close, entries, exits, ... tsl_stop=vbt.Param([0.1, 0.2]), ... tp_stop=vbt.Param([0.2, 0.3]) ... ) >>> pf.asset_flow tsl_stop 0.1 0.2 tp_stop 0.2 0.3 0.2 0.3 0 10.0 10.0 10.0 10.0 1 0.0 0.0 0.0 0.0 2 -10.0 0.0 -10.0 0.0 3 0.0 0.0 0.0 0.0 4 0.0 -10.0 0.0 0.0 5 0.0 0.0 0.0 -10.0 ``` This works because `pd.Index` automatically translates into `vectorbtpro.base.reshaping.BCO` with `product` set to True. 
* We can implement our own stop loss or take profit, or adjust the existing one at each time step. Let's implement [stepped stop-loss](https://www.freqtrade.io/en/stable/strategy-advanced/#stepped-stoploss): ```pycon >>> @njit ... def adjust_func_nb(c): ... val_price_now = c.last_val_price[c.col] ... tsl_init_price = c.last_tsl_info["init_price"][c.col] ... current_profit = (val_price_now - tsl_init_price) / tsl_init_price ... if current_profit >= 0.40: ... c.last_tsl_info["stop"][c.col] = 0.25 ... elif current_profit >= 0.25: ... c.last_tsl_info["stop"][c.col] = 0.15 ... elif current_profit >= 0.20: ... c.last_tsl_info["stop"][c.col] = 0.07 >>> close = pd.Series([10, 11, 12, 11, 10]) >>> pf = vbt.Portfolio.from_signals(close, adjust_func_nb=adjust_func_nb) >>> pf.asset_flow 0 10.0 1 0.0 2 0.0 3 -10.0 # 7% from 12 hit 4 11.16 dtype: float64 ``` * Sometimes there is a need to provide or transform signals dynamically. For this, we can implement a custom signal function `signal_func_nb`. For example, let's implement a signal function that takes two numerical arrays - long and short one - and transforms them into 4 direction-aware boolean arrays that vectorbt understands: ```pycon >>> @njit ... def signal_func_nb(c, long_num_arr, short_num_arr): ... long_num = vbt.pf_nb.select_nb(c, long_num_arr) ... short_num = vbt.pf_nb.select_nb(c, short_num_arr) ... is_long_entry = long_num > 0 ... is_long_exit = long_num < 0 ... is_short_entry = short_num > 0 ... is_short_exit = short_num < 0 ... return is_long_entry, is_long_exit, is_short_entry, is_short_exit >>> pf = vbt.Portfolio.from_signals( ... pd.Series([1, 2, 3, 4, 5]), ... signal_func_nb=signal_func_nb, ... signal_args=(vbt.Rep('long_num_arr'), vbt.Rep('short_num_arr')), ... broadcast_named_args=dict( ... long_num_arr=pd.Series([1, 0, -1, 0, 0]), ... short_num_arr=pd.Series([0, 1, 0, 1, -1]) ... ), ... size=1, ... upon_opposite_entry='ignore' ... 
@classmethod
def from_holding(
    cls: tp.Type[PortfolioT],
    close: tp.Union[tp.ArrayLike, OHLCDataMixin],
    direction: tp.Optional[int] = None,
    at_first_valid_in: tp.Optional[str] = "close",
    close_at_end: tp.Optional[bool] = None,
    dynamic_mode: bool = False,
    **kwargs,
) -> PortfolioResultT:
    """Simulate portfolio from plain holding using signals.

    Args:
        close (array_like or OHLCDataMixin): Passed as the first argument to `Portfolio.from_signals`.
        direction (Direction): Holding direction.

            Mapped against `vectorbtpro.portfolio.enums.Direction` and must resolve to
            a single integer. Defaults to `hold_direction` in the portfolio settings.
        at_first_valid_in (str): Name of the broadcasted object whose first valid value
            in each column determines the entry row.

            If None, the entry is placed at the very first row. If a column has no valid
            value at all, no entry is placed in that column. Not used in dynamic mode.
        close_at_end (bool): Whether to place an opposite signal at the very end.

            Defaults to `close_at_end` in the portfolio settings.
        dynamic_mode (bool): Whether to generate signals during the simulation using
            `holding_enex_signal_func_nb` instead of building signal arrays beforehand.
        **kwargs: Keyword arguments passed to the class method `Portfolio.from_signals`.

    Raises:
        TypeError: If `direction` doesn't resolve to an integer scalar.
    """
    from vectorbtpro._settings import settings

    portfolio_cfg = settings["portfolio"]

    if direction is None:
        direction = portfolio_cfg["hold_direction"]
    direction = map_enum_fields(direction, enums.Direction)
    if not checks.is_int(direction):
        raise TypeError("Direction must be a scalar")
    if close_at_end is None:
        close_at_end = portfolio_cfg["close_at_end"]

    if dynamic_mode:

        def _substitute_signal_args(preparer):
            # The adjustment function is passed as an argument only when the preparer
            # isn't staticized. NOTE(review): presumably a staticized simulator has it
            # built in already - confirm against the preparer.
            return (
                direction,
                close_at_end,
                *((preparer.adjust_func_nb,) if preparer.staticized is None else ()),
                preparer.adjust_args,
            )

        return cls.from_signals(
            close,
            signal_func_nb=nb.holding_enex_signal_func_nb,
            signal_args=RepFunc(_substitute_signal_args),
            accumulate=False,
            **kwargs,
        )

    def _entries(wrapper, new_objs):
        n_rows = wrapper.shape_2d[0]
        n_cols = wrapper.shape_2d[1]
        if at_first_valid_in is None:
            # Single column that broadcasts against all columns: enter at the first row
            entries = np.full((n_rows, 1), False)
            entries[0] = True
            return entries
        ts = new_objs[at_first_valid_in]
        valid_index = generic_nb.first_valid_index_nb(ts)
        if (valid_index == -1).all():
            # No column has a valid value: never enter
            return np.array([[False]])
        if (valid_index == 0).all():
            entries = np.full((n_rows, 1), False)
            entries[0] = True
            return entries
        entries = np.full((n_rows, n_cols), False)
        # -1 marks a column without any valid value. Used as a row index, it would wrap
        # around and place an entry at the last row, so such columns are skipped.
        valid_index = np.broadcast_to(valid_index, (n_cols,))
        has_valid = valid_index != -1
        entries[valid_index[has_valid], np.arange(n_cols)[has_valid]] = True
        return entries

    def _exits(wrapper):
        if close_at_end:
            # Single column that broadcasts against all columns: exit at the last row
            exits = np.full((wrapper.shape_2d[0], 1), False)
            exits[-1] = True
        else:
            exits = np.array([[False]])
        return exits

    return cls.from_signals(
        close,
        entries=RepFunc(_entries),
        exits=RepFunc(_exits),
        direction=direction,
        accumulate=False,
        **kwargs,
    )

@classmethod
def from_random_signals(
    cls: tp.Type[PortfolioT],
    close: tp.Union[tp.ArrayLike, OHLCDataMixin],
    n: tp.Optional[tp.ArrayLike] = None,
    prob: tp.Optional[tp.ArrayLike] = None,
    entry_prob: tp.Optional[tp.ArrayLike] = None,
    exit_prob: tp.Optional[tp.ArrayLike] = None,
    param_product: bool = False,
    seed: tp.Optional[int] = None,
    run_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> PortfolioResultT:
    """Simulate portfolio from random entry and exit signals.

    Generates signals based either on the number of signals `n` or the probability
    of encountering a signal `prob`.

    * If `n` is set, see `vectorbtpro.signals.generators.randnx.RANDNX`.
    * If `prob` is set, see `vectorbtpro.signals.generators.rprobnx.RPROBNX`.

    Based on `Portfolio.from_signals`.

    Args:
        close (array_like or OHLCDataMixin): Close price or data instance.

            Its wrapper determines the shape, index, and columns of the generated signals.
        n (int or array_like): Number of signals, passed to `RANDNX.run`.

            Cannot be combined with probabilities.
        prob (float or array_like): Probability used for both `entry_prob` and `exit_prob`
            wherever they are None.
        entry_prob (float or array_like): Entry probability, passed to `RPROBNX.run`.
        exit_prob (float or array_like): Exit probability, passed to `RPROBNX.run`.
        param_product (bool): Passed to `RPROBNX.run`. Not used if `n` is set.
        seed (int): Seed passed to both the signal generator and `Portfolio.from_signals`.

            Defaults to `seed` in the portfolio settings.
        run_kwargs (dict): Keyword arguments passed to the `run` method of the generator.
        **kwargs: Keyword arguments passed to `Portfolio.from_signals`.

    Raises:
        ValueError: If the close column cannot be found in data, if both `n` and any
            probability are provided, or if neither `n` nor both probabilities are provided.

    !!! note
        To generate random signals, the shape of `close` is used. Broadcasting with
        other arrays happens after the generation.

    Usage:
        * Test multiple combinations of random entries and exits:

        ```pycon
        >>> close = pd.Series([1, 2, 3, 4, 5])
        >>> pf = vbt.Portfolio.from_random_signals(close, n=[2, 1, 0], seed=42)
        >>> pf.orders.count()
        randnx_n
        2    4
        1    2
        0    0
        Name: count, dtype: int64
        ```

        * Test the Cartesian product of entry and exit encounter probabilities:

        ```pycon
        >>> pf = vbt.Portfolio.from_random_signals(
        ...     close,
        ...     entry_prob=[0, 0.5, 1],
        ...     exit_prob=[0, 0.5, 1],
        ...     param_product=True,
        ...     seed=42)
        >>> pf.orders.count()
        rprobnx_entry_prob  rprobnx_exit_prob
        0.0                 0.0                  0
                            0.5                  0
                            1.0                  0
        0.5                 0.0                  1
                            0.5                  4
                            1.0                  3
        1.0                 0.0                  1
                            0.5                  4
                            1.0                  5
        Name: count, dtype: int64
        ```
    """
    from vectorbtpro._settings import settings

    portfolio_cfg = settings["portfolio"]

    parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
    if parsed_data is not None:
        data = parsed_data
        close = data.close
        if close is None:
            raise ValueError("Column for close couldn't be found in data")
        close_wrapper = data.symbol_wrapper
    else:
        close = to_pd_array(close)
        close_wrapper = ArrayWrapper.from_obj(close)
        data = close
    # `prob` is a shortcut that fills whichever of the two probabilities is missing
    if entry_prob is None:
        entry_prob = prob
    if exit_prob is None:
        exit_prob = prob
    if seed is None:
        seed = portfolio_cfg["seed"]
    if run_kwargs is None:
        run_kwargs = {}

    if n is not None and (entry_prob is not None or exit_prob is not None):
        raise ValueError("Must provide either n or entry_prob and exit_prob")
    if n is not None:
        from vectorbtpro.signals.generators.randnx import RANDNX

        rand = RANDNX.run(
            n=n,
            input_shape=close_wrapper.shape,
            input_index=close_wrapper.index,
            input_columns=close_wrapper.columns,
            seed=seed,
            **run_kwargs,
        )
        entries = rand.entries
        exits = rand.exits
    elif entry_prob is not None and exit_prob is not None:
        from vectorbtpro.signals.generators.rprobnx import RPROBNX

        rprobnx = RPROBNX.run(
            entry_prob=entry_prob,
            exit_prob=exit_prob,
            param_product=param_product,
            input_shape=close_wrapper.shape,
            input_index=close_wrapper.index,
            input_columns=close_wrapper.columns,
            seed=seed,
            **run_kwargs,
        )
        entries = rprobnx.entries
        exits = rprobnx.exits
    else:
        raise ValueError("Must provide at least n or entry_prob and exit_prob")

    # Pass the parsed data (or the converted close) rather than the raw input
    return cls.from_signals(data, entries, exits, seed=seed, **kwargs)

@classmethod
def from_optimizer(
    cls: tp.Type[PortfolioT],
    close: tp.Union[tp.ArrayLike, OHLCDataMixin],
    optimizer: PortfolioOptimizer,
    pf_method: str = "from_orders",
    squeeze_groups: bool = True,
    dropna: tp.Optional[str] = None,
    fill_value: tp.Scalar = np.nan,
    size_type: tp.ArrayLike = "targetpercent",
    direction: tp.Optional[tp.ArrayLike] = None,
    cash_sharing: tp.Optional[bool] = True,
    call_seq: tp.Optional[tp.ArrayLike] = "auto",
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> PortfolioResultT:
    """Build portfolio from an optimizer of type `vectorbtpro.portfolio.pfopt.base.PortfolioOptimizer`.

    Uses `Portfolio.from_orders` as the base simulation method by default; set `pf_method`
    to "from_signals" to use `Portfolio.from_signals` (in order mode with accumulation) instead.

    The size type is 'targetpercent'. If there are positive and negative values, the direction
    is automatically set to 'both', otherwise to 'longonly' for positive-only and `shortonly`
    for negative-only values. Also, the cash sharing is set to True, the call sequence is set
    to 'auto', and the grouper is set to the grouper of the optimizer by default.

    Args:
        close (array_like or OHLCDataMixin): Passed as the first argument to the simulation method.
        optimizer (PortfolioOptimizer): Optimizer whose filled allocations become the order size.
        pf_method (str): Either "from_orders" or "from_signals" (case-insensitive).
        squeeze_groups (bool): Passed to `PortfolioOptimizer.fill_allocations`.
        dropna (str): Passed to `PortfolioOptimizer.fill_allocations`.
        fill_value (scalar): Passed to `PortfolioOptimizer.fill_allocations`.
        size_type (SizeType or array_like): Passed to the simulation method.
        direction (Direction or array_like): Passed to the simulation method.

            If None, gets derived from the sign of the allocations. Whenever there is
            no positive allocation, becomes 'shortonly' and the allocations are
            replaced by their absolute values.
        cash_sharing (bool): Passed to the simulation method.
        call_seq (CallSeqType or array_like): Passed to the simulation method.
        group_by (any): Passed to the simulation method.

            If None, gets resolved lazily: the grouper of the optimizer is used as long as
            the columns match the columns of the optimizer, otherwise columns are grouped
            by every level except "symbol" if such a level exists.
        **kwargs: Keyword arguments passed to the simulation method.

    Raises:
        ValueError: If `pf_method` is invalid, or (upon resolving `group_by`) if the column
            hierarchy has changed and there is no "symbol" level.

    Usage:
        ```pycon
        >>> close = pd.DataFrame({
        ...     "MSFT": [1, 2, 3, 4, 5],
        ...     "GOOG": [5, 4, 3, 2, 1],
        ...     "AAPL": [1, 2, 3, 2, 1]
        ... }, index=pd.date_range(start="2020-01-01", periods=5))

        >>> pfo = vbt.PortfolioOptimizer.from_random(
        ...     close.vbt.wrapper,
        ...     every="2D",
        ...     seed=42
        ... )
        >>> pfo.fill_allocations()
                        MSFT      GOOG      AAPL
        2020-01-01  0.182059  0.462129  0.355812
        2020-01-02       NaN       NaN       NaN
        2020-01-03  0.657381  0.171323  0.171296
        2020-01-04       NaN       NaN       NaN
        2020-01-05  0.038078  0.567845  0.394077

        >>> pf = vbt.Portfolio.from_optimizer(close, pfo)
        >>> pf.get_asset_value(group_by=False).vbt / pf.value
        alloc_group                         group
                        MSFT      GOOG      AAPL
        2020-01-01  0.182059  0.462129  0.355812  << rebalanced
        2020-01-02  0.251907  0.255771  0.492322
        2020-01-03  0.657381  0.171323  0.171296  << rebalanced
        2020-01-04  0.793277  0.103369  0.103353
        2020-01-05  0.038078  0.567845  0.394077  << rebalanced
        ```
    """
    size = optimizer.fill_allocations(squeeze_groups=squeeze_groups, dropna=dropna, fill_value=fill_value)

    if direction is None:
        pos_size_any = (size.values > 0).any()
        neg_size_any = (size.values < 0).any()
        if pos_size_any and neg_size_any:
            direction = "both"
        elif pos_size_any:
            direction = "longonly"
        else:
            # Short-only sizes are expressed as positive numbers
            direction = "shortonly"
            size = size.abs()

    if group_by is None:

        def _substitute_group_by(index):
            # NOTE(review): `index` is presumably the final column index of the
            # simulation - confirm against the template context of the preparer.
            columns = optimizer.wrapper.columns
            if squeeze_groups and optimizer.wrapper.grouped_ndim == 1:
                columns = columns.droplevel(level=0)
            if not index.equals(columns):
                if "symbol" in index.names:
                    return ExceptLevel("symbol")
                raise ValueError("Column hierarchy has changed. Disable squeeze_groups and provide group_by.")
            return optimizer.wrapper.grouper.group_by

        group_by = RepFunc(_substitute_group_by)

    method_name = pf_method.lower()
    if method_name == "from_orders":
        return cls.from_orders(
            close,
            size=size,
            size_type=size_type,
            direction=direction,
            cash_sharing=cash_sharing,
            call_seq=call_seq,
            group_by=group_by,
            **kwargs,
        )
    if method_name == "from_signals":
        return cls.from_signals(
            close,
            order_mode=True,
            size=size,
            size_type=size_type,
            direction=direction,
            accumulate=True,
            cash_sharing=cash_sharing,
            call_seq=call_seq,
            group_by=group_by,
            **kwargs,
        )
    raise ValueError(f"Invalid pf_method: '{pf_method}'")
tp.Optional[nb.PostSegmentFuncT] = None, post_segment_args: tp.Args = (), order_func_nb: tp.Optional[nb.OrderFuncT] = None, order_args: tp.Args = (), flex_order_func_nb: tp.Optional[nb.FlexOrderFuncT] = None, flex_order_args: tp.Args = (), post_order_func_nb: tp.Optional[nb.PostOrderFuncT] = None, post_order_args: tp.Args = (), open: tp.Optional[tp.ArrayLike] = None, high: tp.Optional[tp.ArrayLike] = None, low: tp.Optional[tp.ArrayLike] = None, ffill_val_price: tp.Optional[bool] = None, update_value: tp.Optional[bool] = None, fill_pos_info: tp.Optional[bool] = None, track_value: tp.Optional[bool] = None, row_wise: tp.Optional[bool] = None, max_order_records: tp.Optional[int] = None, max_log_records: tp.Optional[int] = None, in_outputs: tp.Optional[tp.MappingLike] = None, seed: tp.Optional[int] = None, group_by: tp.GroupByLike = None, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, keep_inout_flex: tp.Optional[bool] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, staticized: tp.StaticizedOption = None, bm_close: tp.Optional[tp.ArrayLike] = None, records: tp.Optional[tp.RecordsLike] = None, return_preparer: bool = False, return_prep_result: bool = False, return_sim_out: bool = False, **kwargs, ) -> PortfolioResultT: """Build portfolio from a custom order function. !!! hint See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb` for illustrations and argument definitions. 
For more details on individual simulation functions: * `order_func_nb`: See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb` * `order_func_nb` and `row_wise`: See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_rw_nb` * `flex_order_func_nb`: See `vectorbtpro.portfolio.nb.from_order_func.from_flex_order_func_nb` * `flex_order_func_nb` and `row_wise`: See `vectorbtpro.portfolio.nb.from_order_func.from_flex_order_func_rw_nb` Prepared by `vectorbtpro.portfolio.preparing.FOFPreparer`. Args: close (array_like, OHLCDataMixin, FOFPreparer, or PFPrepResult): Latest asset price at each time step. Will broadcast. If an instance of `vectorbtpro.data.base.OHLCDataMixin`, will extract the open, high, low, and close price. Used for calculating unrealized PnL and portfolio value. init_cash (InitCashMode, float or array_like): See `Portfolio.from_orders`. init_position (float or array_like): See `Portfolio.from_orders`. init_price (float or array_like): See `Portfolio.from_orders`. cash_deposits (float or array_like): See `Portfolio.from_orders`. cash_earnings (float or array_like): See `Portfolio.from_orders`. cash_sharing (bool): Whether to share cash within the same group. If `group_by` is None, `group_by` becomes True to form a single group with cash sharing. sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive). sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive). call_seq (CallSeqType or array_like): Default sequence of calls per row and group. * Use `vectorbtpro.portfolio.enums.CallSeqType` to select a sequence type. * Set to array to specify custom sequence. Will not broadcast. !!! note CallSeqType.Auto must be implemented manually. Use `vectorbtpro.portfolio.nb.from_order_func.sort_call_seq_1d_nb` or `vectorbtpro.portfolio.nb.from_order_func.sort_call_seq_out_1d_nb` in `pre_segment_func_nb`. attach_call_seq (bool): See `Portfolio.from_orders`. 
segment_mask (int or array_like of bool): Mask of whether a particular segment should be executed. Supplying an integer will activate every n-th row. Supplying a boolean or an array of boolean will broadcast to the number of rows and groups. Does not broadcast together with `close` and `broadcast_named_args`, only against the final shape. call_pre_segment (bool): Whether to call `pre_segment_func_nb` regardless of `segment_mask`. call_post_segment (bool): Whether to call `post_segment_func_nb` regardless of `segment_mask`. pre_sim_func_nb (callable): Function called before simulation. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`. pre_sim_args (tuple): Packed arguments passed to `pre_sim_func_nb`. post_sim_func_nb (callable): Function called after simulation. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`. post_sim_args (tuple): Packed arguments passed to `post_sim_func_nb`. pre_group_func_nb (callable): Function called before each group. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`. Called only if `row_wise` is False. pre_group_args (tuple): Packed arguments passed to `pre_group_func_nb`. post_group_func_nb (callable): Function called after each group. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`. Called only if `row_wise` is False. post_group_args (tuple): Packed arguments passed to `post_group_func_nb`. pre_row_func_nb (callable): Function called before each row. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`. Called only if `row_wise` is True. pre_row_args (tuple): Packed arguments passed to `pre_row_func_nb`. post_row_func_nb (callable): Function called after each row. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`. Called only if `row_wise` is True. post_row_args (tuple): Packed arguments passed to `post_row_func_nb`. pre_segment_func_nb (callable): Function called before each segment. 
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`. pre_segment_args (tuple): Packed arguments passed to `pre_segment_func_nb`. post_segment_func_nb (callable): Function called after each segment. Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`. post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`. order_func_nb (callable): Order generation function. order_args: Packed arguments passed to `order_func_nb`. flex_order_func_nb (callable): Flexible order generation function. flex_order_args: Packed arguments passed to `flex_order_func_nb`. post_order_func_nb (callable): Callback that is called after the order has been processed. post_order_args (tuple): Packed arguments passed to `post_order_func_nb`. open (array_like of float): See `Portfolio.from_orders`. high (array_like of float): See `Portfolio.from_orders`. low (array_like of float): See `Portfolio.from_orders`. ffill_val_price (bool): Whether to track valuation price only if it's known. Otherwise, unknown `close` will lead to NaN in valuation price at the next timestamp. update_value (bool): Whether to update group value after each filled order. fill_pos_info (bool): Whether to fill position record. Disable this to make simulation faster for simple use cases. track_value (bool): Whether to track value metrics such as the current valuation price, value, and return. Disable this to make simulation faster for simple use cases. row_wise (bool): Whether to iterate over rows rather than columns/groups. max_order_records (int): The max number of order records expected to be filled at each column. Defaults to the number of rows in the broadcasted shape. Set to a lower number if you run out of memory, to 0 to not fill, and to a higher number if there are more than one order expected at each timestamp. max_log_records (int): The max number of log records expected to be filled at each column. Defaults to the number of rows in the broadcasted shape. 
Set to a lower number if you run out of memory, to 0 to not fill, and to a higher number if there are more than one order expected at each timestamp. in_outputs (mapping_like): Mapping with in-output objects. Will be available via `Portfolio.in_outputs` as a named tuple. To substitute `Portfolio` attributes, provide already broadcasted and grouped objects, for example, by using `broadcast_named_args` and templates. Also see `Portfolio.in_outputs_indexing_func` on how in-output objects are indexed. When chunking, make sure to provide the chunk taking specification and the merging function. See `vectorbtpro.portfolio.chunking.merge_sim_outs`. !!! note When using Numba below 0.54, `in_outputs` cannot be a mapping, but must be a named tuple defined globally so Numba can introspect its attributes for pickling. seed (int): See `Portfolio.from_orders`. group_by (any): See `Portfolio.from_orders`. broadcast_named_args (dict): See `Portfolio.from_signals`. broadcast_kwargs (dict): See `Portfolio.from_orders`. template_context (mapping): See `Portfolio.from_signals`. keep_inout_flex (bool): Whether to keep arrays that can be edited in-place raw when broadcasting. Disable this to be able to edit `segment_mask`, `cash_deposits`, and `cash_earnings` during the simulation. jitted (any): See `Portfolio.from_orders`. !!! note Disabling jitting will not disable jitter (such as Numba) on other functions, only on the main (simulation) function. If necessary, you should ensure that every other function is not compiled as well. For example, when working with Numba, you can do this by using the `py_func` attribute of that function. Or, you can disable Numba entirely by running `os.environ['NUMBA_DISABLE_JIT'] = '1'` before importing vectorbtpro. !!! warning Parallelization assumes that groups are independent and there is no data flowing between them. chunked (any): See `vectorbtpro.utils.chunking.resolve_chunked_option`.
staticized (bool, dict, hashable, or callable): Keyword arguments or task id for staticizing. If True or dictionary, will be passed as keyword arguments to `vectorbtpro.utils.cutting.cut_and_save_func` to save a cacheable version of the simulator to a file. If a hashable or callable, will be used as a task id of an already registered jittable and chunkable simulator. Dictionary allows additional options `override` and `reload` to override and reload an already existing module respectively. bm_close (array_like): See `Portfolio.from_orders`. records (array_like): See `Portfolio.from_orders`. return_preparer (bool): See `Portfolio.from_orders`. return_prep_result (bool): See `Portfolio.from_orders`. return_sim_out (bool): See `Portfolio.from_orders`. **kwargs: Keyword arguments passed to the `Portfolio` constructor. For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO` with `fill_value` to override. !!! note All passed functions must be Numba-compiled if Numba is enabled. Also see notes on `Portfolio.from_orders`. !!! note In contrast to other methods, the valuation price is previous `close` instead of the order price since the price of an order is unknown before the call (which is more realistic by the way). You can still override the valuation price in `pre_segment_func_nb`. Usage: * Buy 10 units each tick using closing price: ```pycon >>> @njit ... def order_func_nb(c, size): ... return vbt.pf_nb.order_nb(size=size) >>> close = pd.Series([1, 2, 3, 4, 5]) >>> pf = vbt.Portfolio.from_order_func( ... close, ... order_func_nb=order_func_nb, ... order_args=(10,) ... ) >>> pf.assets 0 10.0 1 20.0 2 30.0 3 40.0 4 40.0 dtype: float64 >>> pf.cash 0 90.0 1 70.0 2 40.0 3 0.0 4 0.0 dtype: float64 ``` * Reverse each position by first closing it. 
Keep state of last position to determine which position to open next (just as an example, there are easier ways to do this): ```pycon >>> @njit ... def pre_group_func_nb(c): ... last_pos_state = np.array([-1]) ... return (last_pos_state,) >>> @njit ... def order_func_nb(c, last_pos_state): ... if c.position_now != 0: ... return vbt.pf_nb.close_position_nb() ... ... if last_pos_state[0] == 1: ... size = -np.inf # open short ... last_pos_state[0] = -1 ... else: ... size = np.inf # open long ... last_pos_state[0] = 1 ... return vbt.pf_nb.order_nb(size=size) >>> pf = vbt.Portfolio.from_order_func( ... close, ... order_func_nb=order_func_nb, ... pre_group_func_nb=pre_group_func_nb ... ) >>> pf.assets 0 100.000000 1 0.000000 2 -66.666667 3 0.000000 4 26.666667 dtype: float64 >>> pf.cash 0 0.000000 1 200.000000 2 400.000000 3 133.333333 4 0.000000 dtype: float64 ``` * Equal-weighted portfolio as in the example under `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb`: ```pycon >>> @njit ... def pre_group_func_nb(c): ... order_value_out = np.empty(c.group_len, dtype=float_) ... return (order_value_out,) >>> @njit ... def pre_segment_func_nb(c, order_value_out, size, price, size_type, direction): ... for col in range(c.from_col, c.to_col): ... c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price) ... vbt.pf_nb.sort_call_seq_nb(c, size, size_type, direction, order_value_out) ... return () >>> @njit ... def order_func_nb(c, size, price, size_type, direction, fees, fixed_fees, slippage): ... return vbt.pf_nb.order_nb( ... size=vbt.pf_nb.select_nb(c, size), ... price=vbt.pf_nb.select_nb(c, price), ... size_type=vbt.pf_nb.select_nb(c, size_type), ... direction=vbt.pf_nb.select_nb(c, direction), ... fees=vbt.pf_nb.select_nb(c, fees), ... fixed_fees=vbt.pf_nb.select_nb(c, fixed_fees), ... slippage=vbt.pf_nb.select_nb(c, slippage) ... 
) >>> np.random.seed(42) >>> close = np.random.uniform(1, 10, size=(5, 3)) >>> size_template = vbt.RepEval('np.array([[1 / group_lens[0]]])') >>> pf = vbt.Portfolio.from_order_func( ... close, ... order_func_nb=order_func_nb, ... order_args=( ... size_template, ... vbt.Rep('price'), ... vbt.Rep('size_type'), ... vbt.Rep('direction'), ... vbt.Rep('fees'), ... vbt.Rep('fixed_fees'), ... vbt.Rep('slippage'), ... ), ... segment_mask=2, # rebalance every second tick ... pre_group_func_nb=pre_group_func_nb, ... pre_segment_func_nb=pre_segment_func_nb, ... pre_segment_args=( ... size_template, ... vbt.Rep('price'), ... vbt.Rep('size_type'), ... vbt.Rep('direction') ... ), ... broadcast_named_args=dict( # broadcast against each other ... price=close, ... size_type=vbt.pf_enums.SizeType.TargetPercent, ... direction=vbt.pf_enums.Direction.LongOnly, ... fees=0.001, ... fixed_fees=1., ... slippage=0.001 ... ), ... template_context=dict(np=np), # required by size_template ... cash_sharing=True, group_by=True, # one group with cash sharing ... ) >>> pf.get_asset_value(group_by=False).vbt.plot().show() ``` ![](/assets/images/api/from_order_func.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_order_func.dark.svg#only-dark){: .iimg loading=lazy } Templates are a very powerful tool to prepare any custom arguments after they are broadcast and before they are passed to the simulation function. In the example above, we use `broadcast_named_args` to broadcast some arguments against each other and templates to pass those objects to callbacks. Additionally, we used an evaluation template to compute the size based on the number of assets in each group. You may ask: why should we bother using broadcasting and templates if we could just pass `size=1/3`? Because of flexibility those features provide: we can now pass whatever parameter combinations we want and it will work flawlessly. 
For example, to create two groups of equally-allocated positions, we need to change only two parameters: ```pycon >>> close = np.random.uniform(1, 10, size=(5, 6)) # 6 columns instead of 3 >>> group_by = ['g1', 'g1', 'g1', 'g2', 'g2', 'g2'] # 2 groups instead of 1 >>> # Replace close and group_by in the example above >>> pf['g1'].get_asset_value(group_by=False).vbt.plot().show() ``` ![](/assets/images/api/from_order_func_g1.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_order_func_g1.dark.svg#only-dark){: .iimg loading=lazy } ```pycon >>> pf['g2'].get_asset_value(group_by=False).vbt.plot().show() ``` ![](/assets/images/api/from_order_func_g2.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/from_order_func_g2.dark.svg#only-dark){: .iimg loading=lazy } * Combine multiple exit conditions. Exit early if the price hits some threshold before an actual exit: ```pycon >>> @njit ... def pre_sim_func_nb(c): ... # We need to define stop price per column once ... stop_price = np.full(c.target_shape[1], np.nan, dtype=float_) ... return (stop_price,) >>> @njit ... def order_func_nb(c, stop_price, entries, exits, size): ... # Select info related to this order ... entry_now = vbt.pf_nb.select_nb(c, entries) ... exit_now = vbt.pf_nb.select_nb(c, exits) ... size_now = vbt.pf_nb.select_nb(c, size) ... price_now = vbt.pf_nb.select_nb(c, c.close) ... stop_price_now = stop_price[c.col] ... ... # Our logic ... if entry_now: ... if c.position_now == 0: ... return vbt.pf_nb.order_nb( ... size=size_now, ... price=price_now, ... direction=vbt.pf_enums.Direction.LongOnly) ... elif exit_now or price_now >= stop_price_now: ... if c.position_now > 0: ... return vbt.pf_nb.order_nb( ... size=-size_now, ... price=price_now, ... direction=vbt.pf_enums.Direction.LongOnly) ... return vbt.pf_enums.NoOrder >>> @njit ... def post_order_func_nb(c, stop_price, stop): ... # Same broadcasting as for size ... stop_now = vbt.pf_nb.select_nb(c, stop) ... ... 
if c.order_result.status == vbt.pf_enums.OrderStatus.Filled: ... if c.order_result.side == vbt.pf_enums.OrderSide.Buy: ... # Position entered: Set stop condition ... stop_price[c.col] = (1 + stop_now) * c.order_result.price ... else: ... # Position exited: Remove stop condition ... stop_price[c.col] = np.nan >>> def simulate(close, entries, exits, size, stop): ... return vbt.Portfolio.from_order_func( ... close, ... order_func_nb=order_func_nb, ... order_args=(vbt.Rep('entries'), vbt.Rep('exits'), vbt.Rep('size')), ... pre_sim_func_nb=pre_sim_func_nb, ... post_order_func_nb=post_order_func_nb, ... post_order_args=(vbt.Rep('stop'),), ... broadcast_named_args=dict( # broadcast against each other ... entries=entries, ... exits=exits, ... size=size, ... stop=stop ... ) ... ) >>> close = pd.Series([10, 11, 12, 13, 14]) >>> entries = pd.Series([True, True, False, False, False]) >>> exits = pd.Series([False, False, False, True, True]) >>> simulate(close, entries, exits, np.inf, 0.1).asset_flow 0 10.0 1 0.0 2 -10.0 3 0.0 4 0.0 dtype: float64 >>> simulate(close, entries, exits, np.inf, 0.2).asset_flow 0 10.0 1 0.0 2 -10.0 3 0.0 4 0.0 dtype: float64 >>> simulate(close, entries, exits, np.inf, np.nan).asset_flow 0 10.0 1 0.0 2 0.0 3 -10.0 4 0.0 dtype: float64 ``` The reason why stop of 10% does not result in an order at the second time step is because it comes at the same time as entry, so it must wait until no entry is present. This can be changed by replacing the statement "elif" with "if", which would execute an exit regardless if an entry is present (similar to using `ConflictMode.Opposite` in `Portfolio.from_signals`). 
We can also test the parameter combinations above all at once (thanks to broadcasting using `vectorbtpro.base.reshaping.broadcast`): ```pycon >>> stop = pd.DataFrame([[0.1, 0.2, np.nan]]) >>> simulate(close, entries, exits, np.inf, stop).asset_flow 0 1 2 0 10.0 10.0 10.0 1 0.0 0.0 0.0 2 -10.0 -10.0 0.0 3 0.0 0.0 -10.0 4 0.0 0.0 0.0 ``` Or much simpler using Cartesian product: ```pycon >>> stop = vbt.Param([0.1, 0.2, np.nan]) >>> simulate(close, entries, exits, np.inf, stop).asset_flow threshold 0.1 0.2 NaN 0 10.0 10.0 10.0 1 0.0 0.0 0.0 2 -10.0 -10.0 0.0 3 0.0 0.0 -10.0 4 0.0 0.0 0.0 ``` This works because `pd.Index` automatically translates into `vectorbtpro.base.reshaping.BCO` with `product` set to True. * Let's illustrate how to generate multiple orders per symbol and bar. For each bar, buy at open and sell at close: ```pycon >>> @njit ... def flex_order_func_nb(c, size): ... if c.call_idx == 0: ... return c.from_col, vbt.pf_nb.order_nb(size=size, price=c.open[c.i, c.from_col]) ... if c.call_idx == 1: ... return c.from_col, vbt.pf_nb.close_position_nb(price=c.close[c.i, c.from_col]) ... return -1, vbt.pf_enums.NoOrder >>> open = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) >>> close = pd.DataFrame({'a': [2, 3, 4], 'b': [3, 4, 5]}) >>> size = 1 >>> pf = vbt.Portfolio.from_order_func( ... close, ... flex_order_func_nb=flex_order_func_nb, ... flex_order_args=(size,), ... open=open, ... max_order_records=close.shape[0] * 2 ... ) >>> pf.orders.readable Order Id Column Timestamp Size Price Fees Side 0 0 a 0 1.0 1.0 0.0 Buy 1 1 a 0 1.0 2.0 0.0 Sell 2 2 a 1 1.0 2.0 0.0 Buy 3 3 a 1 1.0 3.0 0.0 Sell 4 4 a 2 1.0 3.0 0.0 Buy 5 5 a 2 1.0 4.0 0.0 Sell 6 0 b 0 1.0 4.0 0.0 Buy 7 1 b 0 1.0 3.0 0.0 Sell 8 2 b 1 1.0 5.0 0.0 Buy 9 3 b 1 1.0 4.0 0.0 Sell 10 4 b 2 1.0 6.0 0.0 Buy 11 5 b 2 1.0 5.0 0.0 Sell ``` !!! warning Each bar is effectively a black box - we don't know how the price moves in-between. 
@classmethod
def from_def_order_func(
    cls: tp.Type[PortfolioT],
    close: tp.Union[tp.ArrayLike, OHLCDataMixin, FDOFPreparer, PFPrepResult],
    size: tp.Optional[tp.ArrayLike] = None,
    size_type: tp.Optional[tp.ArrayLike] = None,
    direction: tp.Optional[tp.ArrayLike] = None,
    price: tp.Optional[tp.ArrayLike] = None,
    fees: tp.Optional[tp.ArrayLike] = None,
    fixed_fees: tp.Optional[tp.ArrayLike] = None,
    slippage: tp.Optional[tp.ArrayLike] = None,
    min_size: tp.Optional[tp.ArrayLike] = None,
    max_size: tp.Optional[tp.ArrayLike] = None,
    size_granularity: tp.Optional[tp.ArrayLike] = None,
    leverage: tp.Optional[tp.ArrayLike] = None,
    leverage_mode: tp.Optional[tp.ArrayLike] = None,
    reject_prob: tp.Optional[tp.ArrayLike] = None,
    price_area_vio_mode: tp.Optional[tp.ArrayLike] = None,
    allow_partial: tp.Optional[tp.ArrayLike] = None,
    raise_reject: tp.Optional[tp.ArrayLike] = None,
    log: tp.Optional[tp.ArrayLike] = None,
    pre_segment_func_nb: tp.Optional[nb.PreSegmentFuncT] = None,
    order_func_nb: tp.Optional[nb.OrderFuncT] = None,
    flex_order_func_nb: tp.Optional[nb.FlexOrderFuncT] = None,
    val_price: tp.Optional[tp.ArrayLike] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    call_seq: tp.Optional[tp.ArrayLike] = None,
    flexible: bool = False,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    chunked: tp.ChunkedOption = None,
    return_preparer: bool = False,
    return_prep_result: bool = False,
    return_sim_out: bool = False,
    **kwargs,
) -> PortfolioResultT:
    """Build portfolio from the default order function.

    Default order function takes size, price, fees, and other available information, and issues
    an order at each column and time step. Additionally, it uses a segment preprocessing function
    that overrides the valuation price and sorts the call sequence. This way, it behaves similarly to
    `Portfolio.from_orders`, but allows injecting pre- and postprocessing functions to have more
    control over the execution. It also knows how to chunk each argument. The only disadvantage is
    that `Portfolio.from_orders` is more optimized towards performance (up to 5x).

    If `flexible` is True:

    * `pre_segment_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_flex_pre_segment_func_nb`
    * `flex_order_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_flex_order_func_nb`

    If `flexible` is False:

    * `pre_segment_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_pre_segment_func_nb`
    * `order_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_order_func_nb`

    Prepared by `vectorbtpro.portfolio.preparing.FDOFPreparer`.

    For details on other arguments, see `Portfolio.from_orders` and `Portfolio.from_order_func`.

    Args:
        close (array_like, OHLCDataMixin, FDOFPreparer, or PFPrepResult): Close price, or an
            already created preparer or preparation result.

            * If an instance of `FDOFPreparer`, it is used as the preparer as-is and all other
                arguments except the `return_*` flags are ignored.
            * If an instance of `PFPrepResult`, no preparer is created and its target function
                is run directly.
            * Otherwise, every argument of this method (together with `**kwargs`, but without `cls`
                and the `return_*` flags) is forwarded to `FDOFPreparer`. If
                `BasePFPreparer.parse_data` returns data for `close`, that data is passed
                as `data` and `close` is reset to None.
        flexible (bool): Whether to use the flexible default functions listed above.
        return_preparer (bool): Whether to return the preparer without simulating.
        return_prep_result (bool): Whether to return the preparation result without simulating.
        return_sim_out (bool): Whether to return the raw output of the simulation function
            instead of an instance of this class.
        **kwargs: Keyword arguments forwarded to `FDOFPreparer`.

    Usage:
        * Working with `Portfolio.from_def_order_func` is a similar experience to working
        with `Portfolio.from_orders`:

        ```pycon
        >>> close = pd.Series([1, 2, 3, 4, 5])
        >>> pf = vbt.Portfolio.from_def_order_func(close, 10)

        >>> pf.assets
        0    10.0
        1    20.0
        2    30.0
        3    40.0
        4    40.0
        dtype: float64
        >>> pf.cash
        0    90.0
        1    70.0
        2    40.0
        3     0.0
        4     0.0
        dtype: float64
        ```

        * Equal-weighted portfolio as in the example under `Portfolio.from_order_func`
        but much less verbose and with asset value pre-computed during the simulation (= faster):

        ```pycon
        >>> np.random.seed(42)
        >>> close = np.random.uniform(1, 10, size=(5, 3))

        >>> @njit
        ... def post_segment_func_nb(c):
        ...     for col in range(c.from_col, c.to_col):
        ...         c.in_outputs.asset_value_pc[c.i, col] = c.last_position[col] * c.last_val_price[col]

        >>> pf = vbt.Portfolio.from_def_order_func(
        ...     close,
        ...     size=1/3,
        ...     size_type='targetpercent',
        ...     direction='longonly',
        ...     fees=0.001,
        ...     fixed_fees=1.,
        ...     slippage=0.001,
        ...     segment_mask=2,
        ...     cash_sharing=True,
        ...     group_by=True,
        ...     call_seq='auto',
        ...     post_segment_func_nb=post_segment_func_nb,
        ...     call_post_segment=True,
        ...     in_outputs=dict(asset_value_pc=vbt.RepEval('np.empty_like(close)'))
        ... )

        >>> asset_value = pf.wrapper.wrap(pf.in_outputs.asset_value_pc, group_by=False)
        >>> asset_value.vbt.plot().show()
        ```

        ![](/assets/images/api/from_def_order_func.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/from_def_order_func.dark.svg#only-dark){: .iimg loading=lazy }
    """
    if isinstance(close, FDOFPreparer):
        # A ready-made preparer was passed in place of the close price
        preparer = close
        prep_result = None
    elif isinstance(close, PFPrepResult):
        # A ready-made preparation result was passed: there is no preparer to build
        preparer = None
        prep_result = close
    else:
        # Snapshot the call arguments. This must remain the first statement of the branch:
        # any local bound before this call would end up in the preparer's keyword arguments.
        local_kwargs = locals()
        # Flatten **kwargs into the same namespace as the named arguments
        local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
        # Drop everything that is consumed by this method rather than by the preparer
        del local_kwargs["kwargs"]
        del local_kwargs["cls"]
        del local_kwargs["return_preparer"]
        del local_kwargs["return_prep_result"]
        del local_kwargs["return_sim_out"]
        parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
        if parsed_data is not None:
            # Hand the parsed data over as `data` and clear `close`
            local_kwargs["data"] = parsed_data
            local_kwargs["close"] = None
        preparer = FDOFPreparer(**local_kwargs)
        if not return_preparer:
            # Skipped when only the preparer is requested
            preparer.set_seed()
        prep_result = None
    if return_preparer:
        # NOTE: this is None if a `PFPrepResult` was passed as `close`
        return preparer
    if prep_result is None:
        prep_result = preparer.result
    if return_prep_result:
        return prep_result
    sim_out = prep_result.target_func(**prep_result.target_args)
    if return_sim_out:
        return sim_out
    return cls(order_records=sim_out, **prep_result.pf_args)
# ############# Properties ############# #

@property
def cash_sharing(self) -> bool:
    """Whether cash is shared within the same group."""
    return self._cash_sharing

@property
def in_outputs(self) -> tp.Optional[tp.NamedTuple]:
    """Named tuple holding the in-output objects (None if there are none)."""
    return self._in_outputs

@property
def use_in_outputs(self) -> bool:
    """Whether properties should return in-output objects when they are called."""
    return self._use_in_outputs

@property
def fillna_close(self) -> bool:
    """Whether NaN values in `Portfolio.close` should be forward-backward filled."""
    return self._fillna_close

@property
def year_freq(self) -> tp.Optional[tp.PandasFrequency]:
    """Year frequency.

    Resolved by `ReturnsAccessor.get_year_freq` from the stored year frequency
    and the index and frequency of the wrapper."""
    stored_year_freq = self._year_freq
    wrapper = self.wrapper
    return ReturnsAccessor.get_year_freq(year_freq=stored_year_freq, index=wrapper.index, freq=wrapper.freq)

@property
def returns_acc_defaults(self) -> tp.KwargsLike:
    """Default arguments for `vectorbtpro.returns.accessors.ReturnsAccessor`."""
    return self._returns_acc_defaults

@property
def trades_type(self) -> int:
    """Default `vectorbtpro.portfolio.trades.Trades` to use across `Portfolio`."""
    return self._trades_type

@property
def orders_cls(self) -> type:
    """Class used to wrap order records (`Orders` unless overridden)."""
    custom_cls = self._orders_cls
    return Orders if custom_cls is None else custom_cls

@property
def logs_cls(self) -> type:
    """Class used to wrap log records (`Logs` unless overridden)."""
    custom_cls = self._logs_cls
    return Logs if custom_cls is None else custom_cls

@property
def trades_cls(self) -> type:
    """Class used to wrap trade records (`Trades` unless overridden)."""
    custom_cls = self._trades_cls
    return Trades if custom_cls is None else custom_cls

@property
def entry_trades_cls(self) -> type:
    """Class used to wrap entry trade records (`EntryTrades` unless overridden)."""
    custom_cls = self._entry_trades_cls
    return EntryTrades if custom_cls is None else custom_cls

@property
def exit_trades_cls(self) -> type:
    """Class used to wrap exit trade records (`ExitTrades` unless overridden)."""
    custom_cls = self._exit_trades_cls
    return ExitTrades if custom_cls is None else custom_cls

@property
def positions_cls(self) -> type:
    """Class used to wrap position records (`Positions` unless overridden)."""
    custom_cls = self._positions_cls
    return Positions if custom_cls is None else custom_cls

@property
def drawdowns_cls(self) -> type:
    """Class used to wrap drawdown records (`Drawdowns` unless overridden)."""
    custom_cls = self._drawdowns_cls
    return Drawdowns if custom_cls is None else custom_cls

@custom_property(group_by_aware=False)
def call_seq(self) -> tp.Optional[tp.SeriesFrame]:
    """Sequence of calls per row and group.

    Taken from the in-outputs if they are enabled and define `call_seq`, otherwise from
    the stored call sequence. Returns None if no call sequence is available."""
    # In-outputs are consulted only if their usage is enabled
    in_outputs = self.in_outputs if self.use_in_outputs else None
    if in_outputs is not None and hasattr(in_outputs, "call_seq"):
        raw_call_seq = in_outputs.call_seq
    else:
        raw_call_seq = self._call_seq
    if raw_call_seq is not None:
        return self.wrapper.wrap(raw_call_seq, group_by=False)
    return None

@property
def cash_deposits_as_input(self) -> bool:
    """Whether cash deposits are added to the input value when calculating returns.

    Otherwise, they are subtracted from the output value."""
    return self._cash_deposits_as_input

# ############# Price ############# #

@property
def open_flex(self) -> tp.Optional[tp.ArrayLike]:
    """`Portfolio.open` in a format suitable for flexible indexing."""
    # In-outputs take precedence over the stored array, but only if enabled
    in_outputs = self.in_outputs if self.use_in_outputs else None
    if in_outputs is not None and hasattr(in_outputs, "open"):
        return in_outputs.open
    return self._open

@property
def high_flex(self) -> tp.Optional[tp.ArrayLike]:
    """`Portfolio.high` in a format suitable for flexible indexing."""
    # In-outputs take precedence over the stored array, but only if enabled
    in_outputs = self.in_outputs if self.use_in_outputs else None
    if in_outputs is not None and hasattr(in_outputs, "high"):
        return in_outputs.high
    return self._high

@property
def low_flex(self) -> tp.Optional[tp.ArrayLike]:
    """`Portfolio.low` in a format suitable for flexible indexing."""
    # In-outputs take precedence over the stored array, but only if enabled
    in_outputs = self.in_outputs if self.use_in_outputs else None
    if in_outputs is not None and hasattr(in_outputs, "low"):
        return in_outputs.low
    return self._low
@custom_property(group_by_aware=False, resample_func="first")
def open(self) -> tp.Optional[tp.SeriesFrame]:
    """Open price of each bar (None if not available)."""
    flex_open = self.open_flex
    if flex_open is not None:
        return self.wrapper.wrap(flex_open, group_by=False)
    return None

@custom_property(group_by_aware=False, resample_func="max")
def high(self) -> tp.Optional[tp.SeriesFrame]:
    """High price of each bar (None if not available)."""
    flex_high = self.high_flex
    if flex_high is not None:
        return self.wrapper.wrap(flex_high, group_by=False)
    return None

@custom_property(group_by_aware=False, resample_func="min")
def low(self) -> tp.Optional[tp.SeriesFrame]:
    """Low price of each bar (None if not available)."""
    flex_low = self.low_flex
    if flex_low is not None:
        return self.wrapper.wrap(flex_low, group_by=False)
    return None

@custom_property(group_by_aware=False, resample_func="last")
def close(self) -> tp.SeriesFrame:
    """Last asset price at each time step."""
    wrapper = self.wrapper
    return wrapper.wrap(self.close_flex, group_by=False)

@hybrid_method
def get_filled_close(
    cls_or_self,
    close: tp.Optional[tp.SeriesFrame] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get forward and backward filled closing price.

    When called on an instance, `close` and `wrapper` default to those of the instance.
    When called on the class, both must be provided.

    See `vectorbtpro.generic.nb.base.fbfill_nb`."""
    if isinstance(cls_or_self, type):
        # Class-level call: there is no instance to take the defaults from
        checks.assert_not_none(close, arg_name="close")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if close is None:
            close = cls_or_self.close
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    # Resolve the jitting option first, then the chunking option on top of it
    fbfill_func = ch_reg.resolve_option(jit_reg.resolve_option(generic_nb.fbfill_nb, jitted), chunked)
    filled = fbfill_func(to_2d_array(close))
    return wrapper.wrap(filled, group_by=False, **resolve_dict(wrap_kwargs))

@custom_property(group_by_aware=False, resample_func="last")
def bm_close(self) -> tp.Union[None, bool, tp.SeriesFrame]:
    """Benchmark price per unit series.

    Taken from the in-outputs if they are enabled and define `bm_close`, otherwise from
    the stored benchmark close. None and boolean values are returned as they are."""
    # In-outputs are consulted only if their usage is enabled
    in_outputs = self.in_outputs if self.use_in_outputs else None
    if in_outputs is not None and hasattr(in_outputs, "bm_close"):
        raw_bm_close = in_outputs.bm_close
    else:
        raw_bm_close = self._bm_close
    if raw_bm_close is None or isinstance(raw_bm_close, bool):
        return raw_bm_close
    return self.wrapper.wrap(raw_bm_close, group_by=False)
@hybrid_method
def get_weights(
    cls_or_self,
    weights: tp.Union[None, bool, tp.ArrayLike] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
) -> tp.Union[None, tp.ArrayLike, tp.Series]:
    """Get asset weights.

    When called on an instance, `weights` and `wrapper` default to those of the instance.
    When called on the class, `wrapper` must be provided.

    Returns None if the weights are None or False."""
    if isinstance(cls_or_self, type):
        # Class-level call: there is no instance to take the wrapper from
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if weights is None:
            weights = cls_or_self._weights
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    weights_disabled = weights is None or weights is False
    if weights_disabled:
        return None
    return wrapper.wrap_reduced(weights, group_by=False)

# ############# Views ############# #

def apply_weights(
    self: PortfolioT,
    weights: tp.Union[None, bool, tp.ArrayLike] = None,
    rescale: bool = False,
    group_by: tp.GroupByLike = None,
    apply_group_by: bool = False,
    **kwargs,
) -> PortfolioT:
    """Get view of portfolio with asset weights applied and optionally rescaled.

    If `weights` is None or False, it is passed on unchanged and `rescale` has no effect.

    If `rescale` is True, each weight is multiplied by the number of weights in its group
    and divided by their sum (without grouping, all weights are treated as one group).
    For example, weights 0.5 and 0.5 are rescaled to 1.0 and 1.0 respectively, while
    weights 0.7 and 0.3 are rescaled to 1.4 (1.4 * 0.5 = 0.7) and 0.6 (0.6 * 0.5 = 0.3) respectively.

    If `apply_group_by` is True and `group_by` is not None, the returned view is also regrouped.
    Keyword arguments `**kwargs` are passed to `replace`."""
    weights_given = weights is not None and weights is not False
    if weights_given:
        weights = to_1d_array(self.get_weights(weights=weights))
    if weights_given and rescale:
        grouper = self.wrapper.grouper
        if not grouper.is_grouped(group_by=group_by):
            weights = weights * len(weights) / weights.sum()
        else:
            rescaled = np.empty(len(weights), dtype=float_)
            for idxs in grouper.iter_group_idxs(group_by=group_by):
                # Rescale relative to the other weights of the same group
                weights_in_group = weights[idxs]
                rescaled[idxs] = weights_in_group * len(weights_in_group) / weights_in_group.sum()
            weights = rescaled
    regroup_needed = group_by is not None and apply_group_by
    target = self.regroup(group_by=group_by) if regroup_needed else self
    return target.replace(weights=weights, **kwargs)

def disable_weights(self: PortfolioT, **kwargs) -> PortfolioT:
    """Get view of portfolio with asset weights disabled (that is, set to False)."""
    replace_kwargs = dict(weights=False, **kwargs)
    return self.replace(**replace_kwargs)
np.where(init_position > 0, init_position, 0) new_init_price = np.where(init_position > 0, init_price, np.nan) return self.replace( order_records=new_order_records, init_position=new_init_position, init_price=new_init_price, **kwargs, ) def get_short_view( self: PortfolioT, orders: tp.Optional[Orders] = None, init_position: tp.Optional[tp.ArrayLike] = None, init_price: tp.Optional[tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, rec_sim_range: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> PortfolioT: """Get view of portfolio with short positions only.""" if orders is None: orders = self.resolve_shortcut_attr( "orders", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, weights=False, ) if init_position is None: init_position = self._init_position if init_price is None: init_price = self._init_price new_order_records = orders.get_short_view( init_position=init_position, init_price=init_price, jitted=jitted, chunked=chunked, ).values init_position = broadcast_array_to(init_position, self.wrapper.shape_2d[1]) init_price = broadcast_array_to(init_price, self.wrapper.shape_2d[1]) new_init_position = np.where(init_position < 0, init_position, 0) new_init_price = np.where(init_position < 0, init_price, np.nan) return self.replace( order_records=new_order_records, init_position=new_init_position, init_price=new_init_price, **kwargs, ) # ############# Records ############# # @property def order_records(self) -> tp.RecordArray: """A structured NumPy array of order records.""" return self._order_records @hybrid_method def get_orders( cls_or_self, order_records: tp.Optional[tp.RecordArray] = None, open: tp.Optional[tp.SeriesFrame] = None, high: tp.Optional[tp.SeriesFrame] = None, low: tp.Optional[tp.SeriesFrame] = None, close: tp.Optional[tp.SeriesFrame] = None, orders_cls: tp.Optional[type] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: 
@hybrid_method
def get_logs(
    cls_or_self,
    log_records: tp.Optional[tp.RecordArray] = None,
    open: tp.Optional[tp.SeriesFrame] = None,
    high: tp.Optional[tp.SeriesFrame] = None,
    low: tp.Optional[tp.SeriesFrame] = None,
    close: tp.Optional[tp.SeriesFrame] = None,
    logs_cls: tp.Optional[type] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> Logs:
    """Get log records.

    When called on an instance, any missing argument among `log_records`, the OHLC
    data, `logs_cls`, and `wrapper` is taken from the instance. When called on
    the class, `log_records` and `wrapper` are required and `logs_cls` defaults to
    `vectorbtpro.portfolio.logs.Logs`.

    If a simulation start or end gets resolved, the records are first passed through
    `records_within_sim_range_nb`. The result is regrouped using `group_by`.

    `rec_sim_range` is accepted but not read in the body of this method.

    See `vectorbtpro.portfolio.logs.Logs`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(log_records, arg_name="log_records")
        if logs_cls is None:
            logs_cls = Logs
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        log_records = cls_or_self.log_records if log_records is None else log_records
        open = cls_or_self.open_flex if open is None else open
        high = cls_or_self.high_flex if high is None else high
        low = cls_or_self.low_flex if low is None else low
        close = cls_or_self.close_flex if close is None else close
        logs_cls = cls_or_self.logs_cls if logs_cls is None else logs_cls
        wrapper = fix_wrapper_for_records(cls_or_self) if wrapper is None else wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
    has_sim_range = not (sim_start is None and sim_end is None)
    if has_sim_range:
        range_func = jit_reg.resolve_option(nb.records_within_sim_range_nb, jitted)
        log_records = range_func(
            wrapper.shape_2d,
            log_records,
            log_records["col"],
            log_records["idx"],
            sim_start=sim_start,
            sim_end=sim_end,
        )
    logs = logs_cls(
        wrapper,
        log_records,
        open=open,
        high=high,
        low=low,
        close=close,
        **kwargs,
    )
    return logs.regroup(group_by)
@hybrid_method
def get_exit_trades(
    cls_or_self,
    orders: tp.Optional[Orders] = None,
    init_position: tp.Optional[tp.ArrayLike] = None,
    init_price: tp.Optional[tp.ArrayLike] = None,
    exit_trades_cls: tp.Optional[type] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> ExitTrades:
    """Get exit trade records.

    When called on an instance, missing `orders`, `init_position`, `init_price`,
    and `exit_trades_cls` are resolved from the instance. The simulation range is
    forwarded to the order resolution only if `rec_sim_range` is True.
    When called on the class, `orders` is required, `init_position` defaults to 0.0,
    and `exit_trades_cls` defaults to `vectorbtpro.portfolio.trades.ExitTrades`.

    `**kwargs` are passed to `exit_trades_cls.from_orders`.

    See `vectorbtpro.portfolio.trades.ExitTrades`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(orders, arg_name="orders")
        if init_position is None:
            init_position = 0.0
        if exit_trades_cls is None:
            exit_trades_cls = ExitTrades
    else:
        if orders is None:
            # Restrict the order records to the simulation range only on request
            rec_start = sim_start if rec_sim_range else None
            rec_end = sim_end if rec_sim_range else None
            orders = cls_or_self.resolve_shortcut_attr(
                "orders",
                sim_start=rec_start,
                sim_end=rec_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if init_position is None:
            init_position = cls_or_self.resolve_shortcut_attr(
                "init_position",
                wrapper=wrapper,
                keep_flex=True,
            )
        if init_price is None:
            init_price = cls_or_self.resolve_shortcut_attr(
                "init_price",
                wrapper=wrapper,
                keep_flex=True,
            )
        if exit_trades_cls is None:
            exit_trades_cls = cls_or_self.exit_trades_cls

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, group_by=False)
    from_orders_kwargs = dict(
        init_position=init_position,
        init_price=init_price,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    return exit_trades_cls.from_orders(orders, **from_orders_kwargs, **kwargs)
@hybrid_method
def get_trade_history(
    cls_or_self,
    orders: tp.Optional[Orders] = None,
    entry_trades: tp.Optional[EntryTrades] = None,
    exit_trades: tp.Optional[ExitTrades] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
) -> tp.Frame:
    """Get order history merged with entry and exit trades as a readable DataFrame.

    When called on an instance, missing `orders`, `entry_trades`, and `exit_trades`
    are resolved from the instance (the trades are built from the resolved `orders`).
    When called on the class, all three must be provided.

    The readable order records (without "Size", "Price", and "Fees") are inner-merged
    on "Column" and "Order Id" with the concatenated readable entry and exit trade
    records, whose order id, average price, and fees columns are renamed to
    "Order Id", "Price", and "Fees". Missing trade ids are filled with -1, and
    "Entry Trade Id", "Exit Trade Id", and "Position Id" are moved to the end.

    The frames returned by `records_readable` are not modified in place.

    !!! note
        The P&L and return aggregated across the DataFrame may not match the actual total P&L and
        return, as this DataFrame annotates entry and exit orders with the performance relative to
        their respective trade types. To obtain accurate total statistics, aggregate only the
        statistics of either trade type. Additionally, entry orders include open statistics,
        whereas exit orders do not.
    """
    if not isinstance(cls_or_self, type):
        if orders is None:
            orders = cls_or_self.resolve_shortcut_attr(
                "orders",
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if entry_trades is None:
            entry_trades = cls_or_self.resolve_shortcut_attr(
                "entry_trades",
                orders=orders,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if exit_trades is None:
            exit_trades = cls_or_self.resolve_shortcut_attr(
                "exit_trades",
                orders=orders,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
    else:
        checks.assert_not_none(orders, arg_name="orders")
        checks.assert_not_none(entry_trades, arg_name="entry_trades")
        checks.assert_not_none(exit_trades, arg_name="exit_trades")

    # Use non-mutating drop/rename so that the frames handed out by `records_readable`
    # stay intact even if that property returns a shared (e.g. cached) object
    order_history = orders.records_readable.drop(columns=["Size", "Price", "Fees"])

    # Keep only the entry side of entry trades and align its column names with orders
    entry_trade_history = entry_trades.records_readable.drop(
        columns=["Entry Index", "Exit Order Id", "Exit Index", "Avg Exit Price", "Exit Fees"]
    ).rename(
        columns={
            "Entry Order Id": "Order Id",
            "Avg Entry Price": "Price",
            "Entry Fees": "Fees",
        }
    )

    # Keep only the exit side of exit trades and align its column names with orders
    exit_trade_history = exit_trades.records_readable.drop(
        columns=["Exit Index", "Entry Order Id", "Entry Index", "Avg Entry Price", "Entry Fees"]
    ).rename(
        columns={
            "Exit Order Id": "Order Id",
            "Avg Exit Price": "Price",
            "Exit Fees": "Fees",
        }
    )

    trade_history = pd.concat((entry_trade_history, exit_trade_history), axis=0)
    trade_history = pd.merge(order_history, trade_history, on=["Column", "Order Id"])
    trade_history = trade_history.sort_values(by=["Column", "Order Id", "Position Id"])
    trade_history = trade_history.reset_index(drop=True)

    # Each row carries only one of the two trade ids; mark the missing one with -1
    trade_history["Entry Trade Id"] = trade_history["Entry Trade Id"].fillna(-1).astype(int)
    trade_history["Exit Trade Id"] = trade_history["Exit Trade Id"].fillna(-1).astype(int)

    # Popping and re-assigning moves a column to the end of the frame
    for id_column in ("Entry Trade Id", "Exit Trade Id", "Position Id"):
        trade_history[id_column] = trade_history.pop(id_column)
    return trade_history
Pass `group_by=False` to disable.""" if not isinstance(cls_or_self, type): if orders is None: orders = cls_or_self.resolve_shortcut_attr( "orders", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if entry_trades is None: entry_trades = cls_or_self.resolve_shortcut_attr( "entry_trades", orders=orders, sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if exit_trades is None: exit_trades = cls_or_self.resolve_shortcut_attr( "exit_trades", orders=orders, sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) else: checks.assert_not_none(orders, arg_name="orders") checks.assert_not_none(entry_trades, arg_name="entry_trades") checks.assert_not_none(exit_trades, arg_name="exit_trades") if isinstance(orders, FSOrders) and idx_arr is None: idx_arr = "signal_idx" if idx_arr is not None: if isinstance(idx_arr, str): idx_ma = orders.map_field(idx_arr, idx_arr=idx_arr) else: idx_ma = orders.map_array(idx_arr, idx_arr=idx_arr) else: idx_ma = orders.idx order_index = pd.MultiIndex.from_arrays( (orders.col_mapper.get_col_arr(group_by=group_by), orders.id_arr), names=["col", "id"] ) order_idx_sr = pd.Series(idx_ma.values, index=order_index, name="idx") def _get_type_signals(type_order_ids): type_order_ids = type_order_ids.apply_mask(type_order_ids.values != -1) type_order_index = pd.MultiIndex.from_arrays( (type_order_ids.col_mapper.get_col_arr(group_by=group_by), type_order_ids.values), names=["col", "id"] ) type_idx_df = order_idx_sr.loc[type_order_index].reset_index() type_signals = orders.wrapper.fill(False, group_by=group_by) if isinstance(type_signals, pd.Series): type_signals.values[type_idx_df["idx"].values] = True else: type_signals.values[type_idx_df["idx"].values, type_idx_df["col"].values] = True return type_signals return ( _get_type_signals(entry_trades.long_view.entry_order_id), 
@hybrid_method
def get_drawdowns(
    cls_or_self,
    value: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    drawdowns_cls: tp.Optional[type] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> Drawdowns:
    """Get drawdown records from `Portfolio.get_value`.

    When called on an instance, missing `value`, `drawdowns_cls`, and `wrapper` are
    resolved from the instance. The simulation range is forwarded to the value
    resolution only if `rec_sim_range` is True. When called on the class, `value` is
    required and `drawdowns_cls` defaults to `vectorbtpro.generic.drawdowns.Drawdowns`.

    If a wrapper is available, it is resolved with `group_by` before being passed,
    together with `**kwargs`, to `drawdowns_cls.from_price`.

    See `vectorbtpro.generic.drawdowns.Drawdowns`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(value, arg_name="value")
        if drawdowns_cls is None:
            drawdowns_cls = Drawdowns
    else:
        if value is None:
            rec_start = sim_start if rec_sim_range else None
            rec_end = sim_end if rec_sim_range else None
            value = cls_or_self.resolve_shortcut_attr(
                "value",
                sim_start=rec_start,
                sim_end=rec_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        drawdowns_cls = cls_or_self.drawdowns_cls if drawdowns_cls is None else drawdowns_cls
        wrapper = fix_wrapper_for_records(cls_or_self) if wrapper is None else wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, group_by=False)
    # A class-level call may legitimately arrive here without any wrapper
    resolved_wrapper = wrapper if wrapper is None else wrapper.resolve(group_by=group_by)
    return drawdowns_cls.from_price(
        value,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=resolved_wrapper,
        **kwargs,
    )
@hybrid_method
def get_asset_flow(
    cls_or_self,
    direction: tp.Union[str, int] = "both",
    orders: tp.Optional[Orders] = None,
    init_position: tp.Optional[tp.ArrayLike] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get asset flow series per column.

    Returns the total transacted amount of assets at each time step.

    When called on an instance, missing `orders`, `init_position`, and `wrapper` are
    resolved from the instance. The simulation range is forwarded to the order
    resolution only if `rec_sim_range` is True. When called on the class, `orders` is
    required, `init_position` defaults to 0.0, and `wrapper` defaults to `orders.wrapper`.

    `direction` is mapped using `enums.Direction`. The computation is done by
    `asset_flow_nb`, resolved with the `jitted` and `chunked` options."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(orders, arg_name="orders")
        if init_position is None:
            init_position = 0.0
        if wrapper is None:
            wrapper = orders.wrapper
    else:
        if orders is None:
            rec_start = sim_start if rec_sim_range else None
            rec_end = sim_end if rec_sim_range else None
            orders = cls_or_self.resolve_shortcut_attr(
                "orders",
                sim_start=rec_start,
                sim_end=rec_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=None,
            )
        if init_position is None:
            init_position = cls_or_self.resolve_shortcut_attr(
                "init_position",
                wrapper=wrapper,
                keep_flex=True,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
    direction = map_enum_fields(direction, enums.Direction)

    flow_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.asset_flow_nb, jitted), chunked)
    asset_flow = flow_func(
        wrapper.shape_2d,
        orders.values,
        orders.col_mapper.col_map,
        direction=direction,
        init_position=to_1d_array(init_position),
        sim_start=sim_start,
        sim_end=sim_end,
    )
    final_wrap_kwargs = resolve_dict(wrap_kwargs)
    return wrapper.wrap(asset_flow, group_by=False, **final_wrap_kwargs)
@hybrid_method
def get_position_mask(
    cls_or_self,
    direction: tp.Union[str, int] = "both",
    assets: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get position mask per column or group.

    An element is True if there is a position at the given time step.

    When called on an instance, missing `assets` and `wrapper` are resolved from the
    instance; `direction` is only used for resolving `assets`, and the simulation
    range is forwarded to that resolution only if `rec_sim_range` is True.
    When called on the class, `assets` and `wrapper` are required.

    Uses `position_mask_grouped_nb` if grouping is enabled for `group_by`,
    otherwise `position_mask_nb`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(assets, arg_name="assets")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if assets is None:
            rec_start = sim_start if rec_sim_range else None
            rec_end = sim_end if rec_sim_range else None
            assets = cls_or_self.resolve_shortcut_attr(
                "assets",
                direction=direction,
                sim_start=rec_start,
                sim_end=rec_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    # Both kernels share the same call shape; the grouped one additionally needs group lengths
    if wrapper.grouper.is_grouped(group_by=group_by):
        kernel = nb.position_mask_grouped_nb
        kernel_kwargs = dict(
            group_lens=wrapper.grouper.get_group_lens(group_by=group_by),
            sim_start=sim_start,
            sim_end=sim_end,
        )
    else:
        kernel = nb.position_mask_nb
        kernel_kwargs = dict(sim_start=sim_start, sim_end=sim_end)
    mask_func = ch_reg.resolve_option(jit_reg.resolve_option(kernel, jitted), chunked)
    position_mask = mask_func(to_2d_array(assets), **kernel_kwargs)
    return wrapper.wrap(position_mask, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_position_entry_price(
    cls_or_self,
    orders: tp.Optional[Orders] = None,
    init_position: tp.Optional[tp.ArrayLike] = None,
    init_price: tp.Optional[tp.ArrayLike] = None,
    fill_closed_position: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get the position's entry price at each time step.

    When called on an instance, missing `orders`, `init_position`, `init_price`, and
    `wrapper` are resolved from the instance. The simulation range is forwarded to the
    order resolution only if `rec_sim_range` is True. When called on the class,
    `orders` is required, `init_position` defaults to 0.0, `init_price` defaults to NaN,
    and `wrapper` defaults to `orders.wrapper`.

    The computation is done by `get_position_feature_nb` with the feature
    `enums.PositionFeature.EntryPrice`, resolved with the `jitted` and `chunked` options.
    `fill_closed_position` is passed to that function as is.

    The result is wrapped per column using `wrapper.wrap` and `wrap_kwargs`."""
    if not isinstance(cls_or_self, type):
        if orders is None:
            orders = cls_or_self.resolve_shortcut_attr(
                "orders",
                sim_start=sim_start if rec_sim_range else None,
                sim_end=sim_end if rec_sim_range else None,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=None,
            )
        if init_position is None:
            init_position = cls_or_self.resolve_shortcut_attr(
                "init_position",
                wrapper=wrapper,
                keep_flex=True,
            )
        if init_price is None:
            init_price = cls_or_self.resolve_shortcut_attr(
                "init_price",
                wrapper=wrapper,
                keep_flex=True,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        checks.assert_not_none(orders, arg_name="orders")
        if init_position is None:
            init_position = 0.0
        if init_price is None:
            init_price = np.nan
        if wrapper is None:
            wrapper = orders.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    func = jit_reg.resolve_option(nb.get_position_feature_nb, jitted)
    func = ch_reg.resolve_option(func, chunked)
    entry_price = func(
        orders.values,
        to_2d_array(orders.close),
        orders.col_mapper.col_map,
        feature=enums.PositionFeature.EntryPrice,
        init_position=to_1d_array(init_position),
        init_price=to_1d_array(init_price),
        fill_closed_position=fill_closed_position,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    return wrapper.wrap(entry_price, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_cash_deposits(
    cls_or_self,
    cash_deposits_raw: tp.Optional[tp.ArrayLike] = None,
    cash_sharing: tp.Optional[bool] = None,
    split_shared: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    weights: tp.Union[None, bool, tp.ArrayLike] = None,
    keep_flex: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
    """Get cash deposit series per column or group.

    When called on an instance, arguments left as None fall back to the values
    stored on the instance. When called on the class, `cash_sharing` and `wrapper`
    are required and `cash_deposits_raw` defaults to 0.

    Set `keep_flex` to True to keep format suitable for flexible indexing.
    This consumes less memory."""
    if isinstance(cls_or_self, type):
        # Class-level call: there is no instance state to fall back on.
        if cash_deposits_raw is None:
            cash_deposits_raw = 0.0
        checks.assert_not_none(cash_sharing, arg_name="cash_sharing")
        checks.assert_not_none(wrapper, arg_name="wrapper")
        weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
    else:
        if cash_deposits_raw is None:
            cash_deposits_raw = cls_or_self._cash_deposits
        if cash_sharing is None:
            cash_sharing = cls_or_self.cash_sharing
        if weights is None:
            weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
        elif weights is False:
            # False explicitly disables weighting.
            weights = None
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    deposits_2d = to_2d_array(cash_deposits_raw)
    if keep_flex and not deposits_2d.any():
        # No truthy entry at all: hand the raw value back untouched.
        return cash_deposits_raw

    # True when neither weights nor a simulation range would alter the raw values.
    untouched = weights is None and sim_start is None and sim_end is None
    if wrapper.grouper.is_grouped(group_by=group_by):
        if keep_flex and cash_sharing and untouched:
            return cash_deposits_raw
        group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
        kernel = nb.cash_deposits_grouped_nb
        extra_kwargs = {}
    else:
        if keep_flex and not cash_sharing and untouched:
            return cash_deposits_raw
        group_lens = wrapper.grouper.get_group_lens()
        kernel = nb.cash_deposits_nb
        extra_kwargs = dict(split_shared=split_shared)

    func = ch_reg.resolve_option(jit_reg.resolve_option(kernel, jitted), chunked)
    cash_deposits = func(
        wrapper.shape_2d,
        group_lens,
        cash_sharing,
        cash_deposits_raw=deposits_2d,
        weights=None if weights is None else to_1d_array(weights),
        sim_start=sim_start,
        sim_end=sim_end,
        **extra_kwargs,
    )
    if keep_flex:
        return cash_deposits
    return wrapper.wrap(cash_deposits, group_by=group_by, **resolve_dict(wrap_kwargs))
= None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]: """Get earnings in cash series per column or group. Set `keep_flex` to True to keep format suitable for flexible indexing. This consumes less memory.""" if not isinstance(cls_or_self, type): if cash_earnings_raw is None: cash_earnings_raw = cls_or_self._cash_earnings if weights is None: weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper) elif weights is False: weights = None if wrapper is None: wrapper = cls_or_self.wrapper else: if cash_earnings_raw is None: cash_earnings_raw = 0.0 checks.assert_not_none(wrapper, arg_name="wrapper") weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper) sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False) sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False) cash_earnings_arr = to_2d_array(cash_earnings_raw) if keep_flex and not cash_earnings_arr.any(): return cash_earnings_raw if wrapper.grouper.is_grouped(group_by=group_by): group_lens = wrapper.grouper.get_group_lens(group_by=group_by) func = jit_reg.resolve_option(nb.cash_earnings_grouped_nb, jitted) func = ch_reg.resolve_option(func, chunked) cash_earnings = func( wrapper.shape_2d, group_lens, cash_earnings_raw=cash_earnings_arr, weights=to_1d_array(weights) if weights is not None else None, sim_start=sim_start, sim_end=sim_end, ) else: if keep_flex and weights is None and sim_start is None and sim_end is None: return cash_earnings_raw func = jit_reg.resolve_option(nb.cash_earnings_nb, jitted) func = ch_reg.resolve_option(func, chunked) cash_earnings = func( wrapper.shape_2d, cash_earnings_raw=cash_earnings_arr, weights=to_1d_array(weights) if weights is not None else None, sim_start=sim_start, sim_end=sim_end, ) if keep_flex: return cash_earnings return 
@hybrid_method
def get_total_cash_earnings(
    cls_or_self,
    cash_earnings_raw: tp.Optional[tp.ArrayLike] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    weights: tp.Union[None, bool, tp.ArrayLike] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.ArrayLike:
    """Get total cash earnings per column or group.

    Requests the earnings from `get_cash_earnings` with `keep_flex=True` and
    sums them along the first axis, ignoring NaN. `wrapper` is required when
    called on the class."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(wrapper, arg_name="wrapper")
    elif wrapper is None:
        wrapper = cls_or_self.wrapper

    earnings_kwargs = dict(
        cash_earnings_raw=cash_earnings_raw,
        sim_start=sim_start,
        sim_end=sim_end,
        rec_sim_range=rec_sim_range,
        weights=weights,
        keep_flex=True,
        jitted=jitted,
        chunked=chunked,
        wrapper=wrapper,
        group_by=group_by,
        wrap_kwargs=wrap_kwargs,
    )
    earnings = cls_or_self.get_cash_earnings(**earnings_kwargs)
    totals = np.nansum(earnings, axis=0)
    return wrapper.wrap_reduced(totals, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_init_cash(
    cls_or_self,
    init_cash_raw: tp.Optional[tp.ArrayLike] = None,
    cash_deposits: tp.Optional[tp.ArrayLike] = None,
    free_cash_flow: tp.Optional[tp.SeriesFrame] = None,
    cash_sharing: tp.Optional[bool] = None,
    split_shared: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    weights: tp.Union[None, bool, tp.ArrayLike] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Get initial amount of cash per column or group.

    If `init_cash_raw` is an integer contained in `enums.InitCashMode`, the initial
    cash is computed by `nb.align_init_cash_nb` from the free cash flow and the cash
    deposits. Otherwise, `init_cash_raw` is converted to a one-dimensional array and
    passed to `nb.init_cash_grouped_nb` or `nb.init_cash_nb`, depending on whether
    the wrapper is grouped under `group_by`."""
    on_class = isinstance(cls_or_self, type)
    if on_class:
        checks.assert_not_none(init_cash_raw, arg_name="init_cash_raw")
    elif init_cash_raw is None:
        init_cash_raw = cls_or_self._init_cash
    # `init_cash_raw` is not reassigned until the final dispatch, so test the mode once.
    auto_mode = checks.is_int(init_cash_raw) and init_cash_raw in enums.InitCashMode

    if on_class:
        if auto_mode:
            checks.assert_not_none(free_cash_flow, arg_name="free_cash_flow")
            if cash_deposits is None:
                cash_deposits = 0.0
        checks.assert_not_none(cash_sharing, arg_name="cash_sharing")
        checks.assert_not_none(wrapper, arg_name="wrapper")
        weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
    else:
        if auto_mode:
            # Keyword arguments shared by both shortcut lookups below.
            shortcut_kwargs = dict(
                sim_start=sim_start if rec_sim_range else None,
                sim_end=sim_end if rec_sim_range else None,
                rec_sim_range=rec_sim_range,
                weights=weights,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
            if cash_deposits is None:
                cash_deposits = cls_or_self.resolve_shortcut_attr(
                    "cash_deposits",
                    keep_flex=True,
                    **shortcut_kwargs,
                )
            if free_cash_flow is None:
                free_cash_flow = cls_or_self.resolve_shortcut_attr(
                    "cash_flow",
                    free=True,
                    **shortcut_kwargs,
                )
        if cash_sharing is None:
            cash_sharing = cls_or_self.cash_sharing
        if weights is None:
            weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
        elif weights is False:
            weights = None
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)

    if auto_mode:
        func = ch_reg.resolve_option(jit_reg.resolve_option(nb.align_init_cash_nb, jitted), chunked)
        init_cash = func(
            init_cash_raw,
            to_2d_array(free_cash_flow),
            cash_deposits=to_2d_array(cash_deposits),
            sim_start=sim_start,
            sim_end=sim_end,
        )
    else:
        init_cash_raw = to_1d_array(init_cash_raw)
        if wrapper.grouper.is_grouped(group_by=group_by):
            group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
            func = jit_reg.resolve_option(nb.init_cash_grouped_nb, jitted)
            init_cash = func(
                init_cash_raw,
                group_lens,
                cash_sharing,
                weights=None if weights is None else to_1d_array(weights),
            )
        else:
            group_lens = wrapper.grouper.get_group_lens()
            func = jit_reg.resolve_option(nb.init_cash_nb, jitted)
            init_cash = func(
                init_cash_raw,
                group_lens,
                cash_sharing,
                split_shared=split_shared,
                weights=None if weights is None else to_1d_array(weights),
            )
    wrap_kwargs = merge_dicts(dict(name_or_index="init_cash"), wrap_kwargs)
    return wrapper.wrap_reduced(init_cash, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_init_price(
    cls_or_self,
    init_price_raw: tp.Optional[tp.ArrayLike] = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    wrap_kwargs: tp.KwargsLike = None,
    keep_flex: bool = False,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
    """Get initial price per column.

    When called on an instance, `init_price_raw` and `wrapper` fall back to the
    values stored on the instance. When called on the class, both are required.

    Set `keep_flex` to True to return `init_price_raw` as-is, without broadcasting
    it to the number of columns and without wrapping."""
    if not isinstance(cls_or_self, type):
        if init_price_raw is None:
            init_price_raw = cls_or_self._init_price
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        checks.assert_not_none(init_price_raw, arg_name="init_price_raw")
        checks.assert_not_none(wrapper, arg_name="wrapper")

    if keep_flex:
        return init_price_raw
    # The original re-checked `keep_flex` after broadcasting; that branch could
    # never run because of the early return above, so it has been dropped.
    init_price = broadcast_array_to(init_price_raw, wrapper.shape_2d[1])
    wrap_kwargs = merge_dicts(dict(name_or_index="init_price"), wrap_kwargs)
    return wrapper.wrap_reduced(init_price, group_by=False, **wrap_kwargs)
@hybrid_method
def get_init_value(
    cls_or_self,
    init_position_value: tp.Optional[tp.MaybeSeries] = None,
    init_cash: tp.Optional[tp.MaybeSeries] = None,
    split_shared: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Get initial value per column or group.

    Includes initial cash and the value of initial position.

    When called on the class, `init_position_value`, `init_cash`, and `wrapper`
    are all required."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(init_position_value, arg_name="init_position_value")
        checks.assert_not_none(init_cash, arg_name="init_cash")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if init_position_value is None:
            init_position_value = cls_or_self.resolve_shortcut_attr(
                "init_position_value",
                jitted=jitted,
                wrapper=wrapper,
                group_by=group_by,
            )
        if init_cash is None:
            init_cash = cls_or_self.resolve_shortcut_attr(
                "init_cash",
                split_shared=split_shared,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    combine = jit_reg.resolve_option(nb.init_value_nb, jitted)
    init_value = combine(to_1d_array(init_position_value), to_1d_array(init_cash))
    reduced_kwargs = merge_dicts(dict(name_or_index="init_value"), wrap_kwargs)
    return wrapper.wrap_reduced(init_value, group_by=group_by, **reduced_kwargs)
@hybrid_method
def get_asset_value(
    cls_or_self,
    direction: tp.Union[str, int] = "both",
    close: tp.Optional[tp.SeriesFrame] = None,
    assets: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get asset value series per column or group.

    The per-column value is computed by `nb.asset_value_nb` from `close` and
    `assets`; if the wrapper is grouped under `group_by`, the result is then
    passed through `nb.asset_value_grouped_nb`.

    When called on the class, `close`, `assets`, and `wrapper` are required."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(close, arg_name="close")
        checks.assert_not_none(assets, arg_name="assets")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if close is None:
            # Use the filled close only if the instance is configured to do so.
            close = cls_or_self.filled_close if cls_or_self.fillna_close else cls_or_self.close
        if assets is None:
            assets = cls_or_self.resolve_shortcut_attr(
                "assets",
                direction=direction,
                sim_start=sim_start if rec_sim_range else None,
                sim_end=sim_end if rec_sim_range else None,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    column_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.asset_value_nb, jitted), chunked)
    asset_value = column_func(
        to_2d_array(close),
        to_2d_array(assets),
        sim_start=sim_start,
        sim_end=sim_end,
    )
    if wrapper.grouper.is_grouped(group_by=group_by):
        group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
        group_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.asset_value_grouped_nb, jitted), chunked)
        asset_value = group_func(
            asset_value,
            group_lens,
            sim_start=sim_start,
            sim_end=sim_end,
        )
    return wrapper.wrap(asset_value, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_gross_exposure(
    cls_or_self,
    direction: tp.Union[str, int] = "both",
    asset_value: tp.Optional[tp.SeriesFrame] = None,
    value: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get gross exposure.

    Computed by `nb.gross_exposure_nb` from `asset_value` and `value`. When called
    on the class, `asset_value`, `value`, and `wrapper` are required.

    !!! note
        When both directions, `asset_value` must include the addition of the absolute
        long-only and short-only asset values."""
    direction = map_enum_fields(direction, enums.Direction)
    if isinstance(cls_or_self, type):
        checks.assert_not_none(asset_value, arg_name="asset_value")
        checks.assert_not_none(value, arg_name="value")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        # Keyword arguments common to every shortcut lookup in this branch.
        shortcut_kwargs = dict(
            sim_start=sim_start if rec_sim_range else None,
            sim_end=sim_end if rec_sim_range else None,
            rec_sim_range=rec_sim_range,
            jitted=jitted,
            chunked=chunked,
            wrapper=wrapper,
            group_by=group_by,
        )
        if asset_value is None:
            # NOTE(review): grouping is tested on the instance's own wrapper here,
            # not on the `wrapper` argument - confirm this is intended.
            if direction == enums.Direction.Both and cls_or_self.wrapper.grouper.is_grouped(group_by=group_by):
                # For grouped two-sided exposure, resolve each side on its own and add them.
                long_side = cls_or_self.resolve_shortcut_attr(
                    "asset_value",
                    direction="longonly",
                    **shortcut_kwargs,
                )
                short_side = cls_or_self.resolve_shortcut_attr(
                    "asset_value",
                    direction="shortonly",
                    **shortcut_kwargs,
                )
                asset_value = long_side + short_side
            else:
                asset_value = cls_or_self.resolve_shortcut_attr(
                    "asset_value",
                    direction=direction,
                    **shortcut_kwargs,
                )
        if value is None:
            value = cls_or_self.resolve_shortcut_attr("value", **shortcut_kwargs)
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)

    func = ch_reg.resolve_option(jit_reg.resolve_option(nb.gross_exposure_nb, jitted), chunked)
    gross_exposure = func(
        to_2d_array(asset_value),
        to_2d_array(value),
        sim_start=sim_start,
        sim_end=sim_end,
    )
    return wrapper.wrap(gross_exposure, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_allocations(
    cls_or_self,
    direction: tp.Union[str, int] = "both",
    asset_value: tp.Optional[tp.SeriesFrame] = None,
    value: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get portfolio allocation series per column.

    On an instance, the asset value is resolved without grouping (`group_by=False`)
    while the value is resolved with the requested `group_by`; both are passed with
    the group lengths to `nb.allocations_nb`. The result is wrapped per column.

    When called on the class, `asset_value`, `value`, and `wrapper` are required."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(asset_value, arg_name="asset_value")
        checks.assert_not_none(value, arg_name="value")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        # Keyword arguments common to both shortcut lookups; `group_by` differs per lookup.
        shortcut_kwargs = dict(
            sim_start=sim_start if rec_sim_range else None,
            sim_end=sim_end if rec_sim_range else None,
            rec_sim_range=rec_sim_range,
            jitted=jitted,
            chunked=chunked,
            wrapper=wrapper,
        )
        if asset_value is None:
            asset_value = cls_or_self.resolve_shortcut_attr(
                "asset_value",
                direction=direction,
                group_by=False,
                **shortcut_kwargs,
            )
        if value is None:
            value = cls_or_self.resolve_shortcut_attr(
                "value",
                group_by=group_by,
                **shortcut_kwargs,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
    func = ch_reg.resolve_option(jit_reg.resolve_option(nb.allocations_nb, jitted), chunked)
    allocations = func(
        to_2d_array(asset_value),
        to_2d_array(value),
        group_lens,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    return wrapper.wrap(allocations, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_final_value(
    cls_or_self,
    input_value: tp.Optional[tp.MaybeSeries] = None,
    total_profit: tp.Optional[tp.MaybeSeries] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Get final value per column or group.

    Computed as the element-wise sum of `input_value` and `total_profit`.

    When called on an instance, any of `input_value`, `total_profit`, and `wrapper` that
    is None is resolved from the instance. When called on the class, all three are required."""
    if not isinstance(cls_or_self, type):
        # Instance call: fill in whatever the caller did not provide.
        # The simulation range is forwarded as-is since this method does not apply it itself.
        if input_value is None:
            input_value = cls_or_self.resolve_shortcut_attr(
                "input_value",
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if total_profit is None:
            total_profit = cls_or_self.resolve_shortcut_attr(
                "total_profit",
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        # Class call: there is no instance to fall back on, so everything must be passed
        checks.assert_not_none(input_value, arg_name="input_value")
        checks.assert_not_none(total_profit, arg_name="total_profit")
        checks.assert_not_none(wrapper, arg_name="wrapper")

    final_value = to_1d_array(input_value) + to_1d_array(total_profit)
    # User-provided `wrap_kwargs` are merged over the default name
    wrap_kwargs = merge_dicts(dict(name_or_index="final_value"), wrap_kwargs)
    return wrapper.wrap_reduced(final_value, group_by=group_by, **wrap_kwargs)
= None, cash_deposits: tp.Optional[tp.ArrayLike] = None, cash_deposits_as_input: tp.Optional[bool] = None, value: tp.Optional[tp.SeriesFrame] = None, log_returns: bool = False, daily_returns: bool = False, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, rec_sim_range: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get return series per column or group based on portfolio value.""" if not isinstance(cls_or_self, type): if init_value is None: init_value = cls_or_self.resolve_shortcut_attr( "init_value", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if cash_deposits is None: cash_deposits = cls_or_self.resolve_shortcut_attr( "cash_deposits", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, keep_flex=True, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if cash_deposits_as_input is None: cash_deposits_as_input = cls_or_self.cash_deposits_as_input if value is None: value = cls_or_self.resolve_shortcut_attr( "value", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(init_value, arg_name="init_value") if cash_deposits is None: cash_deposits = 0.0 if cash_deposits_as_input is None: cash_deposits_as_input = False checks.assert_not_none(value, arg_name="value") checks.assert_not_none(wrapper, arg_name="wrapper") sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by) sim_end = 
@hybrid_method
def get_asset_pnl(
    cls_or_self,
    init_position_value: tp.Optional[tp.MaybeSeries] = None,
    asset_value: tp.Optional[tp.SeriesFrame] = None,
    cash_flow: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get the asset PnL series (realized plus unrealized) per column or group.

    On an instance, missing inputs are resolved from the portfolio itself.
    On the class, `init_position_value`, `asset_value`, `cash_flow`, and `wrapper`
    must all be provided."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(init_position_value, arg_name="init_position_value")
        checks.assert_not_none(asset_value, arg_name="asset_value")
        checks.assert_not_none(cash_flow, arg_name="cash_flow")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        # The simulation range is forwarded to dependencies only in recursive mode
        dep_sim_start = sim_start if rec_sim_range else None
        dep_sim_end = sim_end if rec_sim_range else None

        if init_position_value is None:
            init_position_value = cls_or_self.resolve_shortcut_attr(
                "init_position_value",
                jitted=jitted,
                wrapper=wrapper,
                group_by=group_by,
            )
        if asset_value is None:
            asset_value = cls_or_self.resolve_shortcut_attr(
                "asset_value",
                sim_start=dep_sim_start,
                sim_end=dep_sim_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if cash_flow is None:
            cash_flow = cls_or_self.resolve_shortcut_attr(
                "cash_flow",
                sim_start=dep_sim_start,
                sim_end=dep_sim_end,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)

    pnl_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.asset_pnl_nb, jitted), chunked)
    asset_pnl = pnl_func(
        to_2d_array(asset_value),
        to_2d_array(cash_flow),
        init_position_value=to_1d_array(init_position_value),
        sim_start=sim_start,
        sim_end=sim_end,
    )
    extra_wrap_kwargs = resolve_dict(wrap_kwargs)
    return wrapper.wrap(asset_pnl, group_by=group_by, **extra_wrap_kwargs)
@hybrid_method
def get_market_value(
    cls_or_self,
    close: tp.Optional[tp.SeriesFrame] = None,
    init_value: tp.Optional[tp.MaybeSeries] = None,
    cash_deposits: tp.Optional[tp.ArrayLike] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get market value series per column or group.

    If grouped, evenly distributes the initial cash among assets in the group.

    !!! note
        Does not take into account fees and slippage. For this, create a separate portfolio."""
    on_instance = not isinstance(cls_or_self, type)

    if on_instance:
        if close is None:
            close = cls_or_self.filled_close if cls_or_self.fillna_close else cls_or_self.close
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        checks.assert_not_none(close, arg_name="close")
        checks.assert_not_none(init_value, arg_name="init_value")
        if cash_deposits is None:
            cash_deposits = 0.0
        checks.assert_not_none(wrapper, arg_name="wrapper")

    grouped = wrapper.grouper.is_grouped(group_by=group_by)

    if on_instance:
        # Both inputs are resolved per column; shared cash is split only in the grouped case
        dep_kwargs = dict(
            sim_start=sim_start if rec_sim_range else None,
            sim_end=sim_end if rec_sim_range else None,
            rec_sim_range=rec_sim_range,
            jitted=jitted,
            chunked=chunked,
            wrapper=wrapper,
            group_by=False,
        )
        if grouped:
            dep_kwargs["split_shared"] = True
        if init_value is None:
            init_value = cls_or_self.resolve_shortcut_attr("init_value", **dep_kwargs)
        if cash_deposits is None:
            cash_deposits = cls_or_self.resolve_shortcut_attr("cash_deposits", keep_flex=True, **dep_kwargs)

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)

    if grouped:
        group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
        value_func = jit_reg.resolve_option(nb.market_value_grouped_nb, jitted)
        value_func = ch_reg.resolve_option(value_func, chunked)
        market_value = value_func(
            to_2d_array(close),
            group_lens,
            to_1d_array(init_value),
            cash_deposits=to_2d_array(cash_deposits),
            sim_start=sim_start,
            sim_end=sim_end,
        )
    else:
        value_func = jit_reg.resolve_option(nb.market_value_nb, jitted)
        value_func = ch_reg.resolve_option(value_func, chunked)
        market_value = value_func(
            to_2d_array(close),
            to_1d_array(init_value),
            cash_deposits=to_2d_array(cash_deposits),
            sim_start=sim_start,
            sim_end=sim_end,
        )
    return wrapper.wrap(market_value, group_by=group_by, **resolve_dict(wrap_kwargs))
column or group.""" if not isinstance(cls_or_self, type): if init_value is None: init_value = cls_or_self.resolve_shortcut_attr( "init_value", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if cash_deposits is None: cash_deposits = cls_or_self.resolve_shortcut_attr( "cash_deposits", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, keep_flex=True, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if cash_deposits_as_input is None: cash_deposits_as_input = cls_or_self.cash_deposits_as_input if market_value is None: market_value = cls_or_self.resolve_shortcut_attr( "market_value", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(init_value, arg_name="init_value") if cash_deposits is None: cash_deposits = 0.0 if cash_deposits_as_input is None: cash_deposits_as_input = False checks.assert_not_none(market_value, arg_name="market_value") checks.assert_not_none(wrapper, arg_name="wrapper") sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by) sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by) func = jit_reg.resolve_option(nb.returns_nb, jitted) func = ch_reg.resolve_option(func, chunked) market_returns = func( to_2d_array(market_value), to_1d_array(init_value), cash_deposits=to_2d_array(cash_deposits), cash_deposits_as_input=cash_deposits_as_input, log_returns=log_returns, sim_start=sim_start, sim_end=sim_end, ) market_returns = wrapper.wrap(market_returns, group_by=group_by, **resolve_dict(wrap_kwargs)) if daily_returns: market_returns = 
@hybrid_method
def get_total_market_return(
    cls_or_self,
    input_value: tp.Optional[tp.MaybeSeries] = None,
    market_value: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """Get the total market return per column or group.

    On an instance, missing `input_value` and `market_value` are resolved from the portfolio.
    On the class, `input_value`, `market_value`, and `wrapper` are required."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(input_value, arg_name="input_value")
        checks.assert_not_none(market_value, arg_name="market_value")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        # Keyword arguments shared by both dependency lookups
        dep_kwargs = dict(
            sim_start=sim_start if rec_sim_range else None,
            sim_end=sim_end if rec_sim_range else None,
            rec_sim_range=rec_sim_range,
            jitted=jitted,
            chunked=chunked,
            wrapper=wrapper,
            group_by=group_by,
        )
        if input_value is None:
            input_value = cls_or_self.resolve_shortcut_attr("input_value", **dep_kwargs)
        if market_value is None:
            market_value = cls_or_self.resolve_shortcut_attr("market_value", **dep_kwargs)
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
    sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)

    return_func = jit_reg.resolve_option(nb.total_market_return_nb, jitted)
    total_market_return = return_func(
        to_2d_array(market_value),
        to_1d_array(input_value),
        sim_start=sim_start,
        sim_end=sim_end,
    )
    reduced_kwargs = merge_dicts(dict(name_or_index="total_market_return"), wrap_kwargs)
    return wrapper.wrap_reduced(total_market_return, group_by=group_by, **reduced_kwargs)
@hybrid_method
def get_bm_returns(
    cls_or_self,
    init_value: tp.Optional[tp.MaybeSeries] = None,
    cash_deposits: tp.Optional[tp.ArrayLike] = None,
    bm_value: tp.Optional[tp.SeriesFrame] = None,
    log_returns: bool = False,
    daily_returns: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.SeriesFrame]:
    """Get benchmark return series per column or group.

    Based on `Portfolio.bm_close` and `Portfolio.get_market_returns`.

    On an instance, `bm_value` is resolved from the portfolio only when it is not provided.
    If the resolved benchmark value is None, returns None. All remaining arguments
    are forwarded to `Portfolio.get_market_returns` with `bm_value` as `market_value`."""
    if not isinstance(cls_or_self, type):
        # Resolve only when missing: a caller-supplied `bm_value` must not be overwritten
        if bm_value is None:
            bm_value = cls_or_self.resolve_shortcut_attr(
                "bm_value",
                sim_start=sim_start if rec_sim_range else None,
                sim_end=sim_end if rec_sim_range else None,
                rec_sim_range=rec_sim_range,
                jitted=jitted,
                chunked=chunked,
                wrapper=wrapper,
                group_by=group_by,
            )
        if bm_value is None:
            # Still nothing to compute returns from
            return None
    return cls_or_self.get_market_returns(
        init_value=init_value,
        cash_deposits=cash_deposits,
        market_value=bm_value,
        log_returns=log_returns,
        daily_returns=daily_returns,
        sim_start=sim_start,
        sim_end=sim_end,
        rec_sim_range=rec_sim_range,
        jitted=jitted,
        chunked=chunked,
        wrapper=wrapper,
        group_by=group_by,
        wrap_kwargs=wrap_kwargs,
    )
hint You can find most methods of this accessor as (cacheable) attributes of this portfolio.""" if not isinstance(cls_or_self, type): if returns is None: if use_asset_returns: returns = cls_or_self.resolve_shortcut_attr( "asset_returns", log_returns=log_returns, daily_returns=daily_returns, sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) else: returns = cls_or_self.resolve_shortcut_attr( "returns", log_returns=log_returns, daily_returns=daily_returns, sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) if bm_returns is None or (isinstance(bm_returns, bool) and bm_returns): bm_returns = cls_or_self.resolve_shortcut_attr( "bm_returns", log_returns=log_returns, daily_returns=daily_returns, sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, jitted=jitted, chunked=chunked, wrapper=wrapper, group_by=group_by, ) elif isinstance(bm_returns, bool) and not bm_returns: bm_returns = None if freq is None: freq = cls_or_self.wrapper.freq if year_freq is None: year_freq = cls_or_self.year_freq defaults = merge_dicts(cls_or_self.returns_acc_defaults, defaults) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(returns, arg_name="returns") sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by) sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by) if daily_returns: freq = "D" if wrapper is not None: wrapper = wrapper.resolve(group_by=group_by) return returns.vbt.returns( wrapper=wrapper, bm_returns=bm_returns, log_returns=log_returns, freq=freq, year_freq=year_freq, defaults=defaults, sim_start=sim_start, sim_end=sim_end, 
@hybrid_method
def get_qs(
    cls_or_self,
    returns: tp.Optional[tp.SeriesFrame] = None,
    use_asset_returns: bool = False,
    bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
    log_returns: bool = False,
    daily_returns: bool = False,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    freq: tp.Optional[tp.FrequencyLike] = None,
    year_freq: tp.Optional[tp.FrequencyLike] = None,
    defaults: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> QSAdapterT:
    """Get quantstats adapter of type `vectorbtpro.returns.qs_adapter.QSAdapter`.

    All named arguments are forwarded to `Portfolio.get_returns_acc` to build the
    returns accessor, while `**kwargs` are passed to the adapter constructor."""
    # Imported lazily, inside the method
    from vectorbtpro.returns.qs_adapter import QSAdapter

    acc_kwargs = dict(
        returns=returns,
        use_asset_returns=use_asset_returns,
        bm_returns=bm_returns,
        log_returns=log_returns,
        daily_returns=daily_returns,
        sim_start=sim_start,
        sim_end=sim_end,
        rec_sim_range=rec_sim_range,
        freq=freq,
        year_freq=year_freq,
        defaults=defaults,
        jitted=jitted,
        chunked=chunked,
        wrapper=wrapper,
        group_by=group_by,
    )
    return QSAdapter(cls_or_self.get_returns_acc(**acc_kwargs), **kwargs)
def resolve_shortcut_attr(self, attr_name: str, *args, **kwargs) -> tp.Any:
    """Resolve an attribute that may have shortcut properties.

    If `attr_name` has a prefix `get_`, checks whether the respective shortcut property
    can be called. This way, complex call hierarchies can utilize cacheable properties.

    The shortcut property is used only if no positional arguments are passed and every
    keyword argument is either consumed by the checks below or identical (by identity)
    to the default of the respective `get_` method:

    * `free` and `direction` select the `free_`, `long_`, or `short_` variant of the property.
    * `group_by` must not modify the grouping if the property is group-aware (option
      `group_aware`, True by default), and must not enable grouping otherwise.
    * `jitted` and `chunked` are ignored.

    In all other cases, calls the method with `*args` and `**kwargs`."""
    if not attr_name.startswith("get_"):
        if "get_" + attr_name not in self.cls_dir or (len(args) == 0 and len(kwargs) == 0):
            # There is no `get_` counterpart or nothing to pass: use the attribute directly
            if isinstance(getattr(type(self), attr_name), property):
                return getattr(self, attr_name)
            return getattr(self, attr_name)(*args, **kwargs)
        attr_name = "get_" + attr_name

    if len(args) == 0:
        naked_attr_name = attr_name[4:]
        prop_name = naked_attr_name
        # Work on a copy since recognized keys are popped while checking
        _kwargs = dict(kwargs)

        if "free" in _kwargs:
            if _kwargs.pop("free"):
                prop_name = "free_" + naked_attr_name
        if "direction" in _kwargs:
            direction = map_enum_fields(_kwargs.pop("direction"), enums.Direction)
            if direction == enums.Direction.LongOnly:
                prop_name = "long_" + naked_attr_name
            elif direction == enums.Direction.ShortOnly:
                prop_name = "short_" + naked_attr_name

        if prop_name in self.cls_dir:
            prop = getattr(type(self), prop_name)
            options = getattr(prop, "options", {})

            can_call_prop = True
            if "group_by" in _kwargs:
                # Pop exactly once: popping the key a second time without
                # a default would raise KeyError
                group_by = _kwargs.pop("group_by")
                if options.get("group_aware", True):
                    if self.wrapper.grouper.is_grouping_modified(group_by=group_by):
                        can_call_prop = False
                elif self.wrapper.grouper.is_grouping_enabled(group_by=group_by):
                    can_call_prop = False
            if can_call_prop:
                _kwargs.pop("jitted", None)
                _kwargs.pop("chunked", None)
                for k, v in get_func_kwargs(getattr(type(self), attr_name)).items():
                    # Any argument that differs from the method's default rules the property out
                    if k in _kwargs and v is not _kwargs.pop(k):
                        can_call_prop = False
                        break
                # Leftover arguments are unknown to the property
                if can_call_prop and not _kwargs:
                    return getattr(self, prop_name)

    return getattr(self, attr_name)(*args, **kwargs)
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and `stats` from `vectorbtpro._settings.portfolio`.""" from vectorbtpro._settings import settings portfolio_stats_cfg = settings["portfolio"]["stats"] return merge_dicts( Analyzable.stats_defaults.__get__(self), dict(settings=dict(trades_type=self.trades_type)), portfolio_stats_cfg, ) _metrics: tp.ClassVar[Config] = HybridConfig( dict( start_index=dict( title="Start Index", calc_func="sim_start_index", tags="wrapper", ), end_index=dict( title="End Index", calc_func="sim_end_index", tags="wrapper", ), total_duration=dict( title="Total Duration", calc_func="sim_duration", apply_to_timedelta=True, tags="wrapper", ), start_value=dict( title="Start Value", calc_func="init_value", tags="portfolio", ), min_value=dict( title="Min Value", calc_func="value.vbt.min", tags="portfolio", ), max_value=dict( title="Max Value", calc_func="value.vbt.max", tags="portfolio", ), end_value=dict( title="End Value", calc_func="final_value", tags="portfolio", ), cash_deposits=dict( title="Total Cash Deposits", calc_func="total_cash_deposits", check_has_cash_deposits=True, tags="portfolio", ), cash_earnings=dict( title="Total Cash Earnings", calc_func="total_cash_earnings", check_has_cash_earnings=True, tags="portfolio", ), total_return=dict( title="Total Return [%]", calc_func="total_return", post_calc_func=lambda self, out, settings: out * 100, tags="portfolio", ), bm_return=dict( title="Benchmark Return [%]", calc_func="bm_returns.vbt.returns.total", post_calc_func=lambda self, out, settings: out * 100, check_has_bm_returns=True, tags="portfolio", ), total_time_exposure=dict( title="Position Coverage [%]", calc_func="position_coverage", post_calc_func=lambda self, out, settings: out * 100, tags="portfolio", ), max_gross_exposure=dict( title="Max Gross Exposure [%]", calc_func="gross_exposure.vbt.max", post_calc_func=lambda self, out, settings: out * 100, tags="portfolio", ), max_dd=dict( title="Max Drawdown [%]", 
calc_func="drawdowns.max_drawdown", post_calc_func=lambda self, out, settings: -out * 100, tags=["portfolio", "drawdowns"], ), max_dd_duration=dict( title="Max Drawdown Duration", calc_func="drawdowns.max_duration", fill_wrap_kwargs=True, tags=["portfolio", "drawdowns", "duration"], ), total_orders=dict( title="Total Orders", calc_func="orders.count", tags=["portfolio", "orders"], ), total_fees_paid=dict( title="Total Fees Paid", calc_func="orders.fees.sum", tags=["portfolio", "orders"], ), total_trades=dict( title="Total Trades", calc_func="trades.count", incl_open=True, tags=["portfolio", "trades"], ), win_rate=dict( title="Win Rate [%]", calc_func="trades.win_rate", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['portfolio', 'trades', *incl_open_tags]"), ), best_trade=dict( title="Best Trade [%]", calc_func="trades.returns.max", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['portfolio', 'trades', *incl_open_tags]"), ), worst_trade=dict( title="Worst Trade [%]", calc_func="trades.returns.min", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['portfolio', 'trades', *incl_open_tags]"), ), avg_winning_trade=dict( title="Avg Winning Trade [%]", calc_func="trades.winning.returns.mean", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'winning']"), ), avg_losing_trade=dict( title="Avg Losing Trade [%]", calc_func="trades.losing.returns.mean", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'losing']"), ), avg_winning_trade_duration=dict( title="Avg Winning Trade Duration", calc_func="trades.winning.duration.mean", apply_to_timedelta=True, tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'winning', 'duration']"), ), avg_losing_trade_duration=dict( title="Avg Losing Trade Duration", calc_func="trades.losing.duration.mean", apply_to_timedelta=True, tags=RepEval("['portfolio', 
'trades', *incl_open_tags, 'losing', 'duration']"), ), profit_factor=dict( title="Profit Factor", calc_func="trades.profit_factor", tags=RepEval("['portfolio', 'trades', *incl_open_tags]"), ), expectancy=dict( title="Expectancy", calc_func="trades.expectancy", tags=RepEval("['portfolio', 'trades', *incl_open_tags]"), ), sharpe_ratio=dict( title="Sharpe Ratio", calc_func="returns_acc.sharpe_ratio", check_has_freq=True, check_has_year_freq=True, tags=["portfolio", "returns"], ), calmar_ratio=dict( title="Calmar Ratio", calc_func="returns_acc.calmar_ratio", check_has_freq=True, check_has_year_freq=True, tags=["portfolio", "returns"], ), omega_ratio=dict( title="Omega Ratio", calc_func="returns_acc.omega_ratio", check_has_freq=True, check_has_year_freq=True, tags=["portfolio", "returns"], ), sortino_ratio=dict( title="Sortino Ratio", calc_func="returns_acc.sortino_ratio", check_has_freq=True, check_has_year_freq=True, tags=["portfolio", "returns"], ), ) ) @property def metrics(self) -> Config: return self._metrics def returns_stats( self, use_asset_returns: bool = False, bm_returns: tp.Union[None, bool, tp.ArrayLike] = None, log_returns: bool = False, daily_returns: bool = False, freq: tp.Optional[tp.FrequencyLike] = None, year_freq: tp.Optional[tp.FrequencyLike] = None, defaults: tp.KwargsLike = None, chunked: tp.ChunkedOption = None, group_by: tp.GroupByLike = None, **kwargs, ) -> tp.SeriesFrame: """Compute various statistics on returns of this portfolio. See `Portfolio.returns_acc` and `vectorbtpro.returns.accessors.ReturnsAccessor.metrics`. `kwargs` will be passed to `vectorbtpro.returns.accessors.ReturnsAccessor.stats` method. 
@hybrid_method
def plot_orders(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    # Quoted so the hint does not depend on `Orders` being importable at definition time.
    orders: tp.Optional["Orders"] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    xref: tp.Optional[str] = None,
    yref: tp.Optional[str] = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column of orders.

    When called on an instance and `orders` is None, the order records are resolved with
    `resolve_shortcut_attr("orders", ...)` using `sim_start`, `sim_end`, `rec_sim_range`,
    and `wrapper`. When called on the class, `orders` must be provided.

    If `fit_sim_range` is True, the figure is passed through `fit_fig_to_sim_range`
    (with `group_by=False`). In that case, if `xref` is None, it is taken from the
    `xaxis` of the last trace of the figure, falling back to "x".

    `yref` is accepted for signature parity with the other plotting methods
    but is not used by this method.

    `**kwargs` are passed to `vectorbtpro.portfolio.orders.Orders.plot`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(orders, arg_name="orders")
    elif orders is None:
        orders = cls_or_self.resolve_shortcut_attr(
            "orders",
            sim_start=sim_start,
            sim_end=sim_end,
            rec_sim_range=rec_sim_range,
            wrapper=wrapper,
        )

    fig = orders.plot(column=column, **kwargs)
    if not fit_sim_range:
        return fig

    # The x-axis reference is only needed for fitting, so infer it lazily
    if xref is None:
        trace_xaxis = fig.data[-1]["xaxis"]
        xref = trace_xaxis if trace_xaxis is not None else "x"
    return cls_or_self.fit_fig_to_sim_range(
        fig,
        column=column,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=wrapper,
        group_by=False,
        xref=xref,
    )
@hybrid_method
def plot_trade_pnl(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    # Quoted so the hint does not depend on `Trades` being importable at definition time.
    trades: tp.Optional["Trades"] = None,
    trades_type: tp.Optional[tp.Union[str, int]] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    pct_scale: bool = False,
    xref: str = "x",
    yref: str = "y",
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column of trade P&L.

    When called on an instance and `trades` is None, the trade records are resolved with
    `resolve_shortcut_attr("trades", ...)` using `trades_type`, `sim_start`, `sim_end`,
    `rec_sim_range`, and `wrapper`. When called on the class, `trades` must be provided.

    `pct_scale`, `xref`, and `yref` are forwarded to the plotting method of the records.
    If `fit_sim_range` is True, the figure is then passed through `fit_fig_to_sim_range`
    (with `group_by=False`).

    `**kwargs` are passed to `vectorbtpro.portfolio.trades.Trades.plot_pnl`."""
    if isinstance(cls_or_self, type):
        checks.assert_not_none(trades, arg_name="trades")
    elif trades is None:
        trades = cls_or_self.resolve_shortcut_attr(
            "trades",
            trades_type=trades_type,
            sim_start=sim_start,
            sim_end=sim_end,
            rec_sim_range=rec_sim_range,
            wrapper=wrapper,
        )

    fig = trades.plot_pnl(column=column, pct_scale=pct_scale, xref=xref, yref=yref, **kwargs)
    if fit_sim_range:
        fig = cls_or_self.fit_fig_to_sim_range(
            fig,
            column=column,
            sim_start=sim_start,
            sim_end=sim_end,
            wrapper=wrapper,
            group_by=False,
            xref=xref,
        )
    return fig
Markers and shapes are colored by trade direction (green = long, red = short).""" from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if not isinstance(cls_or_self, type): if entry_trades is None: entry_trades = cls_or_self.resolve_shortcut_attr( "entry_trades", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=False, ) if exit_trades is None: exit_trades = cls_or_self.resolve_shortcut_attr( "exit_trades", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=False, ) if positions is None: positions = cls_or_self.resolve_shortcut_attr( "positions", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=False, ) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(entry_trades, arg_name="entry_trades") checks.assert_not_none(exit_trades, arg_name="exit_trades") checks.assert_not_none(positions, arg_name="positions") if wrapper is None: wrapper = entry_trades.wrapper fig = entry_trades.plot_signals( column=column, long_entry_trace_kwargs=long_entry_trace_kwargs, short_entry_trace_kwargs=short_entry_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, **kwargs, ) fig = exit_trades.plot_signals( column=column, plot_ohlc=False, plot_close=False, long_exit_trace_kwargs=long_exit_trace_kwargs, short_exit_trace_kwargs=short_exit_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if xref is None: xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x" if yref is None: yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y" if isinstance(plot_positions, bool): if plot_positions: plot_positions = "zones" else: plot_positions = None if plot_positions is not None: if plot_positions.lower() == "zones": long_shape_kwargs = merge_dicts( dict(fillcolor=plotting_cfg["contrast_color_schema"]["green"]), long_shape_kwargs, ) short_shape_kwargs = merge_dicts( 
@hybrid_method
def plot_cash_flow(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    free: bool = False,
    cash_flow: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    line_shape: str = "hv",
    xref: tp.Optional[str] = None,
    yref: tp.Optional[str] = None,
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column or group of cash flow.

    On an instance, a missing `cash_flow` is resolved with
    `resolve_shortcut_attr("cash_flow", ...)` and a missing `wrapper` defaults to
    the wrapper of the instance. On the class, both must be provided.

    A dashed gray horizontal line is added at zero; `hline_shape_kwargs` is merged
    with its default shape arguments. If `xref` or `yref` is None, it is taken from
    the last trace of the figure, falling back to "x" and "y" respectively.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
    from vectorbtpro.utils.figure import get_domain
    from vectorbtpro._settings import settings

    plot_settings = settings["plotting"]

    if isinstance(cls_or_self, type):
        checks.assert_not_none(cash_flow, arg_name="cash_flow")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if cash_flow is None:
            cash_flow = cls_or_self.resolve_shortcut_attr(
                "cash_flow",
                free=free,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    # Narrow down to the requested column/group and draw the main line
    cash_flow = wrapper.select_col_from_obj(cash_flow, column=column, group_by=group_by)
    line_style = {"color": plot_settings["color_schema"]["green"], "shape": line_shape}
    trace_defaults = {"trace_kwargs": {"line": line_style, "name": "Cash"}}
    kwargs = merge_dicts(trace_defaults, kwargs)
    fig = cash_flow.vbt.lineplot(**kwargs)

    # Infer axis references from the last trace when not given
    if xref is None:
        trace_xaxis = fig.data[-1]["xaxis"]
        xref = "x" if trace_xaxis is None else trace_xaxis
    if yref is None:
        trace_yaxis = fig.data[-1]["yaxis"]
        yref = "y" if trace_yaxis is None else trace_yaxis

    # Zero line spanning the domain of the x-axis
    x_domain = get_domain(xref, fig)
    zero_line = {
        "type": "line",
        "line": {"color": "gray", "dash": "dash"},
        "xref": "paper",
        "yref": yref,
        "x0": x_domain[0],
        "y0": 0.0,
        "x1": x_domain[1],
        "y1": 0.0,
    }
    fig.add_shape(**merge_dicts(zero_line, hline_shape_kwargs))

    if not fit_sim_range:
        return fig
    return cls_or_self.fit_fig_to_sim_range(
        fig,
        column=column,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=wrapper,
        group_by=group_by,
        xref=xref,
    )
= None, **kwargs, ) -> tp.BaseFigure: """Plot one column or group of cash balance. `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`.""" from vectorbtpro.utils.figure import get_domain from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if not isinstance(cls_or_self, type): if init_cash is None: init_cash = cls_or_self.resolve_shortcut_attr( "init_cash", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if cash is None: cash = cls_or_self.resolve_shortcut_attr( "cash", free=free, sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(init_cash, arg_name="init_cash") checks.assert_not_none(cash, arg_name="cash") checks.assert_not_none(wrapper, arg_name="wrapper") init_cash = wrapper.select_col_from_obj(init_cash, column=column, group_by=group_by) cash = wrapper.select_col_from_obj(cash, column=column, group_by=group_by) kwargs = merge_dicts( dict( trace_kwargs=dict( line=dict( color=plotting_cfg["color_schema"]["green"], shape=line_shape, ), name="Cash", ), pos_trace_kwargs=dict( fillcolor=adjust_opacity(plotting_cfg["color_schema"]["green"], 0.3), line=dict(shape=line_shape), ), neg_trace_kwargs=dict( fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3), line=dict(shape=line_shape), ), other_trace_kwargs="hidden", ), kwargs, ) fig = cash.vbt.plot_against(init_cash, **kwargs) if xref is None: xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x" if yref is None: yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y" x_domain = get_domain(xref, fig) fig.add_shape( **merge_dicts( dict( type="line", line=dict( color="gray", dash="dash", ), xref="paper", yref=yref, x0=x_domain[0], y0=init_cash, x1=x_domain[1], 
@hybrid_method
def plot_asset_flow(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    direction: tp.Union[str, int] = "both",
    asset_flow: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    line_shape: str = "hv",
    xref: tp.Optional[str] = None,
    yref: tp.Optional[str] = None,
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column of asset flow.

    On an instance, a missing `asset_flow` is resolved with
    `resolve_shortcut_attr("asset_flow", ...)` and a missing `wrapper` defaults to
    the wrapper of the instance. On the class, both must be provided.
    The column is always selected with `group_by=False`.

    A dashed gray horizontal line is added at zero; `hline_shape_kwargs` is merged
    with its default shape arguments. If `xref` or `yref` is None, it is taken from
    the last trace of the figure, falling back to "x" and "y" respectively.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
    from vectorbtpro.utils.figure import get_domain
    from vectorbtpro._settings import settings

    plot_settings = settings["plotting"]

    if isinstance(cls_or_self, type):
        checks.assert_not_none(asset_flow, arg_name="asset_flow")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if asset_flow is None:
            asset_flow = cls_or_self.resolve_shortcut_attr(
                "asset_flow",
                direction=direction,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    # Narrow down to the requested column and draw the main line
    asset_flow = wrapper.select_col_from_obj(asset_flow, column=column, group_by=False)
    line_style = {"color": plot_settings["color_schema"]["blue"], "shape": line_shape}
    trace_defaults = {"trace_kwargs": {"line": line_style, "name": "Assets"}}
    kwargs = merge_dicts(trace_defaults, kwargs)
    fig = asset_flow.vbt.lineplot(**kwargs)

    # Infer axis references from the last trace when not given
    if xref is None:
        trace_xaxis = fig.data[-1]["xaxis"]
        xref = "x" if trace_xaxis is None else trace_xaxis
    if yref is None:
        trace_yaxis = fig.data[-1]["yaxis"]
        yref = "y" if trace_yaxis is None else trace_yaxis

    # Zero line spanning the domain of the x-axis
    x_domain = get_domain(xref, fig)
    zero_line = {
        "type": "line",
        "line": {"color": "gray", "dash": "dash"},
        "xref": "paper",
        "yref": yref,
        "x0": x_domain[0],
        "y0": 0,
        "x1": x_domain[1],
        "y1": 0,
    }
    fig.add_shape(**merge_dicts(zero_line, hline_shape_kwargs))

    if not fit_sim_range:
        return fig
    return cls_or_self.fit_fig_to_sim_range(
        fig,
        column=column,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=wrapper,
        group_by=False,
        xref=xref,
    )
@hybrid_method
def plot_asset_value(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    direction: tp.Union[str, int] = "both",
    asset_value: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    line_shape: str = "hv",
    xref: tp.Optional[str] = None,
    yref: tp.Optional[str] = None,
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column or group of asset value.

    On an instance, a missing `asset_value` is resolved with
    `resolve_shortcut_attr("asset_value", ...)` and a missing `wrapper` defaults to
    the wrapper of the instance. On the class, both must be provided.

    The series is plotted against zero, and a dashed gray horizontal line is added
    at zero; `hline_shape_kwargs` is merged with its default shape arguments.
    If `xref` or `yref` is None, it is taken from the last trace of the figure,
    falling back to "x" and "y" respectively.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
    from vectorbtpro.utils.figure import get_domain
    from vectorbtpro._settings import settings

    plot_settings = settings["plotting"]

    if isinstance(cls_or_self, type):
        checks.assert_not_none(asset_value, arg_name="asset_value")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if asset_value is None:
            asset_value = cls_or_self.resolve_shortcut_attr(
                "asset_value",
                direction=direction,
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    # Narrow down to the requested column/group
    asset_value = wrapper.select_col_from_obj(asset_value, column=column, group_by=group_by)

    # Default styling: purple line, purple fill above zero, orange fill below
    main_trace = {
        "line": {"color": plot_settings["color_schema"]["purple"], "shape": line_shape},
        "name": "Value",
    }
    above_trace = {
        "fillcolor": adjust_opacity(plot_settings["color_schema"]["purple"], 0.3),
        "line": {"shape": line_shape},
    }
    below_trace = {
        "fillcolor": adjust_opacity(plot_settings["color_schema"]["orange"], 0.3),
        "line": {"shape": line_shape},
    }
    plot_defaults = {
        "trace_kwargs": main_trace,
        "pos_trace_kwargs": above_trace,
        "neg_trace_kwargs": below_trace,
        "other_trace_kwargs": "hidden",
    }
    kwargs = merge_dicts(plot_defaults, kwargs)
    fig = asset_value.vbt.plot_against(0, **kwargs)

    # Infer axis references from the last trace when not given
    if xref is None:
        trace_xaxis = fig.data[-1]["xaxis"]
        xref = "x" if trace_xaxis is None else trace_xaxis
    if yref is None:
        trace_yaxis = fig.data[-1]["yaxis"]
        yref = "y" if trace_yaxis is None else trace_yaxis

    # Zero line spanning the domain of the x-axis
    x_domain = get_domain(xref, fig)
    zero_line = {
        "type": "line",
        "line": {"color": "gray", "dash": "dash"},
        "xref": "paper",
        "yref": yref,
        "x0": x_domain[0],
        "y0": 0.0,
        "x1": x_domain[1],
        "y1": 0.0,
    }
    fig.add_shape(**merge_dicts(zero_line, hline_shape_kwargs))

    if not fit_sim_range:
        return fig
    return cls_or_self.fit_fig_to_sim_range(
        fig,
        column=column,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=wrapper,
        group_by=group_by,
        xref=xref,
    )
tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, rec_sim_range: bool = False, fit_sim_range: bool = True, wrapper: tp.Optional[ArrayWrapper] = None, group_by: tp.GroupByLike = None, xref: tp.Optional[str] = None, yref: tp.Optional[str] = None, hline_shape_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.BaseFigure: """Plot one column or group of value. `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`.""" from vectorbtpro.utils.figure import get_domain from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] if not isinstance(cls_or_self, type): if init_value is None: init_value = cls_or_self.resolve_shortcut_attr( "init_value", sim_start=sim_start if rec_sim_range else None, sim_end=sim_end if rec_sim_range else None, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if value is None: value = cls_or_self.resolve_shortcut_attr( "value", sim_start=sim_start, sim_end=sim_end, rec_sim_range=rec_sim_range, wrapper=wrapper, group_by=group_by, ) if wrapper is None: wrapper = cls_or_self.wrapper else: checks.assert_not_none(init_value, arg_name="init_value") checks.assert_not_none(value, arg_name="value") checks.assert_not_none(wrapper, arg_name="wrapper") init_value = wrapper.select_col_from_obj(init_value, column=column, group_by=group_by) value = wrapper.select_col_from_obj(value, column=column, group_by=group_by) kwargs = merge_dicts( dict( trace_kwargs=dict( line=dict( color=plotting_cfg["color_schema"]["purple"], ), name="Value", ), pos_trace_kwargs=dict( fillcolor=adjust_opacity(plotting_cfg["color_schema"]["purple"], 0.3), ), neg_trace_kwargs=dict( fillcolor=adjust_opacity(plotting_cfg["color_schema"]["red"], 0.3), ), other_trace_kwargs="hidden", ), kwargs, ) fig = value.vbt.plot_against(init_value, **kwargs) if xref is None: xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x" if yref is None: yref = fig.data[-1]["yaxis"] if 
@hybrid_method
def plot_cumulative_returns(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    returns_acc: tp.Optional[ReturnsAccessor] = None,
    use_asset_returns: bool = False,
    bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    pct_scale: bool = False,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column or group of cumulative returns.

    On an instance, a missing `returns_acc` is resolved with
    `resolve_shortcut_attr("returns_acc", ...)` using `use_asset_returns`, `bm_returns`,
    `sim_start`, `sim_end`, `rec_sim_range`, `wrapper`, and `group_by`.
    On the class, `returns_acc` must be provided.

    If `bm_returns` is None, will use `Portfolio.get_market_returns`.

    `column`, `fit_sim_range`, and `pct_scale` are forwarded to the plotting method.

    `**kwargs` are passed to `vectorbtpro.returns.accessors.ReturnsSRAccessor.plot_cumulative`."""
    from vectorbtpro._settings import settings

    plot_settings = settings["plotting"]

    if isinstance(cls_or_self, type):
        checks.assert_not_none(returns_acc, arg_name="returns_acc")
    elif returns_acc is None:
        returns_acc = cls_or_self.resolve_shortcut_attr(
            "returns_acc",
            use_asset_returns=use_asset_returns,
            bm_returns=bm_returns,
            sim_start=sim_start,
            sim_end=sim_end,
            rec_sim_range=rec_sim_range,
            wrapper=wrapper,
            group_by=group_by,
        )

    # Default styling of the main trace and of the horizontal reference line
    main_defaults = {
        "trace_kwargs": {"name": "Value"},
        "pos_trace_kwargs": {
            "fillcolor": adjust_opacity(plot_settings["color_schema"]["purple"], 0.3),
        },
        "neg_trace_kwargs": {
            "fillcolor": adjust_opacity(plot_settings["color_schema"]["red"], 0.3),
        },
    }
    hline_defaults = {
        "type": "line",
        "line": {"color": "gray", "dash": "dash"},
    }
    kwargs = merge_dicts(
        {"main_kwargs": main_defaults, "hline_shape_kwargs": hline_defaults},
        kwargs,
    )
    return returns_acc.plot_cumulative(
        column=column,
        fit_sim_range=fit_sim_range,
        pct_scale=pct_scale,
        **kwargs,
    )
@hybrid_method
def plot_underwater(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    init_value: tp.Optional[tp.MaybeSeries] = None,
    returns_acc: tp.Optional[ReturnsAccessor] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    pct_scale: bool = True,
    xref: str = "x",
    yref: str = "y",
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column or group of underwater.

    The drawdown series is taken from `returns_acc.drawdown()`. If `pct_scale` is True,
    it is plotted as is and the y-axis gets a percentage tick format. Otherwise, it is
    converted to absolute units using `returns_acc.cumulative()` and `init_value`.

    On an instance, missing `returns_acc` and `wrapper` are resolved from the instance;
    `init_value` is resolved only if `pct_scale` is False since it is not used otherwise.
    On the class, `returns_acc` and `wrapper` must be provided, and `init_value` must be
    provided if `pct_scale` is False.

    A dashed gray horizontal line is added at zero; `hline_shape_kwargs` is merged
    with its default shape arguments.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
    from vectorbtpro.utils.figure import get_domain
    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]

    if not isinstance(cls_or_self, type):
        # The initial value only feeds the absolute-scale conversion below
        if init_value is None and not pct_scale:
            init_value = cls_or_self.resolve_shortcut_attr(
                "init_value",
                sim_start=sim_start if rec_sim_range else None,
                sim_end=sim_end if rec_sim_range else None,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if returns_acc is None:
            returns_acc = cls_or_self.resolve_shortcut_attr(
                "returns_acc",
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper
    else:
        if not pct_scale:
            checks.assert_not_none(init_value, arg_name="init_value")
        checks.assert_not_none(returns_acc, arg_name="returns_acc")
        checks.assert_not_none(wrapper, arg_name="wrapper")

    drawdown = returns_acc.drawdown()
    drawdown = wrapper.select_col_from_obj(drawdown, column=column, group_by=group_by)
    if not pct_scale:
        cumulative_returns = returns_acc.cumulative()
        cumret = wrapper.select_col_from_obj(cumulative_returns, column=column, group_by=group_by)
        init_value = wrapper.select_col_from_obj(init_value, column=column, group_by=group_by)
        # With value = cumret * init_value and peak = value / (1 + drawdown),
        # the absolute drawdown is value - peak = value * drawdown / (1 + drawdown)
        drawdown = cumret * init_value * drawdown / (1 + drawdown)

    default_kwargs = dict(
        trace_kwargs=dict(
            line=dict(color=plotting_cfg["color_schema"]["red"]),
            fillcolor=adjust_opacity(plotting_cfg["color_schema"]["red"], 0.3),
            fill="tozeroy",
            name="Drawdown",
        )
    )
    if pct_scale:
        # Translate the axis reference into its layout key, such as "y2" -> "yaxis2"
        yaxis = "yaxis" + yref[1:]
        default_kwargs[yaxis] = dict(tickformat=".2%")
    kwargs = merge_dicts(default_kwargs, kwargs)
    fig = drawdown.vbt.lineplot(**kwargs)

    # Zero line spanning the domain of the x-axis
    x_domain = get_domain(xref, fig)
    fig.add_shape(
        **merge_dicts(
            dict(
                type="line",
                line=dict(
                    color="gray",
                    dash="dash",
                ),
                xref="paper",
                yref=yref,
                x0=x_domain[0],
                y0=0,
                x1=x_domain[1],
                y1=0,
            ),
            hline_shape_kwargs,
        )
    )
    if fit_sim_range:
        fig = cls_or_self.fit_fig_to_sim_range(
            fig,
            column=column,
            sim_start=sim_start,
            sim_end=sim_end,
            wrapper=wrapper,
            group_by=group_by,
            xref=xref,
        )
    return fig
@hybrid_method
def plot_net_exposure(
    cls_or_self,
    column: tp.Optional[tp.Label] = None,
    net_exposure: tp.Optional[tp.SeriesFrame] = None,
    sim_start: tp.Optional[tp.ArrayLike] = None,
    sim_end: tp.Optional[tp.ArrayLike] = None,
    rec_sim_range: bool = False,
    fit_sim_range: bool = True,
    wrapper: tp.Optional[ArrayWrapper] = None,
    group_by: tp.GroupByLike = None,
    line_shape: str = "hv",
    xref: tp.Optional[str] = None,
    yref: tp.Optional[str] = None,
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column or group of net exposure.

    On an instance, a missing `net_exposure` is resolved with
    `resolve_shortcut_attr("net_exposure", ...)` and a missing `wrapper` defaults to
    the wrapper of the instance. On the class, both must be provided.

    The series is plotted against one, and a dashed gray horizontal line is added
    at one; `hline_shape_kwargs` is merged with its default shape arguments.
    If `xref` or `yref` is None, it is taken from the last trace of the figure,
    falling back to "x" and "y" respectively.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
    from vectorbtpro.utils.figure import get_domain
    from vectorbtpro._settings import settings

    plot_settings = settings["plotting"]

    if isinstance(cls_or_self, type):
        checks.assert_not_none(net_exposure, arg_name="net_exposure")
        checks.assert_not_none(wrapper, arg_name="wrapper")
    else:
        if net_exposure is None:
            net_exposure = cls_or_self.resolve_shortcut_attr(
                "net_exposure",
                sim_start=sim_start,
                sim_end=sim_end,
                rec_sim_range=rec_sim_range,
                wrapper=wrapper,
                group_by=group_by,
            )
        if wrapper is None:
            wrapper = cls_or_self.wrapper

    # Narrow down to the requested column/group
    net_exposure = wrapper.select_col_from_obj(net_exposure, column=column, group_by=group_by)

    # Default styling: pink line, orange fill above the baseline, pink fill below
    main_trace = {
        "line": {"color": plot_settings["color_schema"]["pink"], "shape": line_shape},
        "name": "Exposure",
    }
    above_trace = {
        "fillcolor": adjust_opacity(plot_settings["color_schema"]["orange"], 0.3),
        "line": {"shape": line_shape},
    }
    below_trace = {
        "fillcolor": adjust_opacity(plot_settings["color_schema"]["pink"], 0.3),
        "line": {"shape": line_shape},
    }
    plot_defaults = {
        "trace_kwargs": main_trace,
        "pos_trace_kwargs": above_trace,
        "neg_trace_kwargs": below_trace,
        "other_trace_kwargs": "hidden",
    }
    kwargs = merge_dicts(plot_defaults, kwargs)
    # NOTE(review): the baseline of 1 is the same one used for gross exposure;
    # confirm that 1 rather than 0 is the intended reference level for net exposure
    fig = net_exposure.vbt.plot_against(1, **kwargs)

    # Infer axis references from the last trace when not given
    if xref is None:
        trace_xaxis = fig.data[-1]["xaxis"]
        xref = "x" if trace_xaxis is None else trace_xaxis
    if yref is None:
        trace_yaxis = fig.data[-1]["yaxis"]
        yref = "y" if trace_yaxis is None else trace_yaxis

    # Baseline spanning the domain of the x-axis
    x_domain = get_domain(xref, fig)
    baseline = {
        "type": "line",
        "line": {"color": "gray", "dash": "dash"},
        "xref": "paper",
        "yref": yref,
        "x0": x_domain[0],
        "y0": 1,
        "x1": x_domain[1],
        "y1": 1,
    }
    fig.add_shape(**merge_dicts(baseline, hline_shape_kwargs))

    if not fit_sim_range:
        return fig
    return cls_or_self.fit_fig_to_sim_range(
        fig,
        column=column,
        sim_start=sim_start,
        sim_end=sim_end,
        wrapper=wrapper,
        group_by=group_by,
        xref=xref,
    )
@property
def plots_defaults(self) -> tp.Kwargs:
    """Defaults for `Portfolio.plot`.

    Merges, in this order, the defaults obtained through `Analyzable.plots_defaults`
    (see `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults`), the
    `trades_type` setting of this instance, and `plots` from `vectorbtpro._settings.portfolio`."""
    from vectorbtpro._settings import settings

    global_plots_cfg = settings["portfolio"]["plots"]
    inherited_defaults = Analyzable.plots_defaults.__get__(self)
    instance_defaults = {"settings": {"trades_type": self.trades_type}}
    return merge_dicts(inherited_defaults, instance_defaults, global_plots_cfg)
@classmethod
def build_in_output_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
    """Build in-output config documentation.

    Takes the docstring of the `in_output_config` attribute of `source_cls` (`Portfolio`
    if None) as a `string.Template`, and substitutes `$in_output_config` with the
    prettified in-output config of this class and `$cls_name` with the name of this class."""
    doc_owner = Portfolio if source_cls is None else source_cls
    raw_doc = get_dict_attr(doc_owner, "in_output_config").__doc__
    doc_template = string.Template(inspect.cleandoc(raw_doc))
    return doc_template.substitute(
        in_output_config=cls.in_output_config.prettify(),
        cls_name=cls.__name__,
    )
@register_jitted(cache=True)
def shuffle_call_seq_nb(call_seq: tp.Array2d, group_lens: tp.GroupLens) -> None:
    """Shuffle the call sequence array in place.

    Columns are split into consecutive groups according to `group_lens`. In each row,
    only the values that belong to the same group are permuted among themselves."""
    n_rows = call_seq.shape[0]
    col_start = 0
    for group_idx in range(len(group_lens)):
        col_end = col_start + group_lens[group_idx]
        for row in range(n_rows):
            # The slice is a view, so shuffling it modifies `call_seq` itself
            np.random.shuffle(call_seq[row, col_start:col_end])
        col_start = col_end
def build_call_seq(
    target_shape: tp.Shape,
    group_lens: tp.GroupLens,
    call_seq_type: int = CallSeqType.Default,
) -> tp.Array2d:
    """Not compiled but faster version of `build_call_seq_nb`.

    Builds a single row of per-group call positions, broadcasts it to `target_shape`,
    shuffles it per group if `call_seq_type` is `CallSeqType.Random`, and returns the
    result passed through `require_call_seq`."""
    # Start with ones so that a cumulative sum yields a running counter
    row = np.full(target_shape[1], 1, dtype=int_)
    if call_seq_type == CallSeqType.Reversed:
        # Reset the counter at the last column of every group but the last one
        # (cumulative lengths without the final entry are the starts of the next groups)
        prev_group_last_cols = np.cumsum(group_lens)[:-1] - 1
        row[prev_group_last_cols] -= group_lens[1:]
        # Count from right to left to get descending positions within each group
        row = np.cumsum(row[::-1])[::-1] - 1
    else:
        # Reset the counter at the first column of every group but the first one
        next_group_first_cols = np.cumsum(group_lens[:-1])
        row[next_group_first_cols] -= group_lens[:-1]
        row = np.cumsum(row) - 1

    call_seq = np.broadcast_to(row, target_shape)
    if call_seq_type == CallSeqType.Random:
        # The broadcast result must be turned into a writeable array before shuffling in place
        call_seq = require_call_seq(call_seq)
        shuffle_call_seq_nb(call_seq, group_lens)
    return require_call_seq(call_seq)
def get_init_cash_slicer(ann_args: tp.AnnArgs) -> ArraySlicer:
    """Get slicer for `init_cash` based on cash sharing.

    Reads the value of `cash_sharing` from the annotated arguments. Returns a new
    `vectorbtpro.base.chunking.FlexArraySlicer` if cash sharing is enabled, otherwise
    `vectorbtpro.base.chunking.flex_1d_array_gl_slicer`."""
    shares_cash = ann_args["cash_sharing"]["value"]
    return base_ch.FlexArraySlicer() if shares_cash else base_ch.flex_1d_array_gl_slicer
def merge_sim_outs(
    results: tp.List[SimulationOutput],
    chunk_meta: tp.Iterable[ChunkMeta],
    ann_args: tp.AnnArgs,
    mapper: base_ch.GroupLensMapper,
    in_outputs_merge_func: tp.Callable = in_outputs_merge_func,
    **kwargs,
) -> SimulationOutput:
    """Merge chunks of `vectorbtpro.portfolio.enums.SimulationOutput` instances.

    Order and log records are merged using `vectorbtpro.records.chunking.merge_records`.
    Cash deposits and cash earnings are column-stacked only if the array of the first chunk
    has the shape `target_shape` (taken from `ann_args`), otherwise the array of the first
    chunk is used as is. The call sequence is column-stacked, and the simulation start and end
    are concatenated, unless they are None in the first chunk.

    If `SimulationOutput.in_outputs` is not None, must provide `in_outputs_merge_func` or similar.
    It is called with `results`, `chunk_meta`, `ann_args`, `mapper`, and `**kwargs`."""

    def _gather(field: str) -> list:
        # Collect one field from every chunk, preserving the chunk order
        return [getattr(chunk_out, field) for chunk_out in results]

    order_records = merge_records(_gather("order_records"), chunk_meta, ann_args=ann_args, mapper=mapper)
    log_records = merge_records(_gather("log_records"), chunk_meta, ann_args=ann_args, mapper=mapper)

    target_shape = ann_args["target_shape"]["value"]
    # The first chunk decides how each optional field is treated
    first = results[0]

    cash_deposits = (
        column_stack_arrays(_gather("cash_deposits"))
        if first.cash_deposits.shape == target_shape
        else first.cash_deposits
    )
    cash_earnings = (
        column_stack_arrays(_gather("cash_earnings"))
        if first.cash_earnings.shape == target_shape
        else first.cash_earnings
    )
    call_seq = None if first.call_seq is None else column_stack_arrays(_gather("call_seq"))
    in_outputs = (
        None
        if first.in_outputs is None
        else in_outputs_merge_func(results, chunk_meta, ann_args, mapper, **kwargs)
    )
    sim_start = None if first.sim_start is None else concat_arrays(_gather("sim_start"))
    sim_end = None if first.sim_end is None else concat_arrays(_gather("sim_end"))

    return SimulationOutput(
        order_records=order_records,
        log_records=log_records,
        cash_deposits=cash_deposits,
        cash_earnings=cash_earnings,
        call_seq=call_seq,
        in_outputs=in_outputs,
        sim_start=sim_start,
        sim_end=sim_end,
    )
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Class decorators for portfolio.""" from vectorbtpro import _typing as tp from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.utils import checks from vectorbtpro.utils.config import Config, resolve_dict from vectorbtpro.utils.decorators import cacheable_property, cached_property from vectorbtpro.utils.parsing import get_func_arg_names __all__ = [] def attach_returns_acc_methods(config: Config) -> tp.ClassWrapper: """Class decorator to attach returns accessor methods. `config` must contain target method names (keys) and settings (values) with the following keys: * `source_name`: Name of the source method. Defaults to the target name. * `docstring`: Method docstring. 
def attach_shortcut_properties(config: Config) -> tp.ClassWrapper:
    """Class decorator to attach shortcut properties.

    `config` must contain target property names (keys) and settings (values) with the following keys:

    * `method_name`: Name of the source method. Defaults to the target name prepended with the prefix `get_`.
        If explicitly set to None, the property can only be served from in-outputs and
        raises a `ValueError` otherwise.
    * `use_in_outputs`: Whether the property can return an in-place output. Defaults to True.
    * `method_kwargs`: Keyword arguments passed to the source method. Defaults to None.
    * `decorator`: Defaults to `vectorbtpro.utils.decorators.cached_property` for object types
        'records' and 'red_array'. Otherwise, to `vectorbtpro.utils.decorators.cacheable_property`.
    * `docstring`: Method docstring. If not provided, gets generated from the class name,
        the method name, and the method keyword arguments.
    * Other keyword arguments are passed to the decorator and can include settings for wrapping,
        indexing, resampling, stacking, etc.

    Target property names must not start with `get_`, otherwise a `ValueError` is raised.

    The class must be a subclass of `vectorbtpro.portfolio.base.Portfolio`."""

    def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
        checks.assert_subclass_of(cls, "Portfolio")

        for target_name, settings in config.items():
            # Work on a copy since the keys consumed below are popped
            settings = dict(settings)
            if target_name.startswith("get_"):
                raise ValueError(f"Property names cannot have prefix 'get_' ('{target_name}')")
            method_name = settings.pop("method_name", "get_" + target_name)
            use_in_outputs = settings.pop("use_in_outputs", True)
            method_kwargs = settings.pop("method_kwargs", None)
            method_kwargs = resolve_dict(method_kwargs)
            decorator = settings.pop("decorator", None)
            if decorator is None:
                # `obj_type` is only read here (not popped), so it still reaches the decorator
                if settings.get("obj_type", "array") in ("red_array", "records"):
                    decorator = cached_property
                else:
                    decorator = cacheable_property
            docstring = settings.pop("docstring", None)
            if docstring is None:
                if len(method_kwargs) == 0:
                    docstring = f"`{cls.__name__}.{method_name}` with default arguments."
                else:
                    docstring = f"`{cls.__name__}.{method_name}` with arguments `{method_kwargs}`."

            # The loop variables are bound as argument defaults so that each property keeps
            # the values of its own iteration instead of those of the last one
            def new_prop(
                self,
                _method_name: tp.Optional[str] = method_name,
                _target_name: str = target_name,
                _use_in_outputs: bool = use_in_outputs,
                _method_kwargs: tp.Kwargs = method_kwargs,
            ) -> tp.Any:
                # Prefer a prefilled in-output if both the property and the instance allow it
                if _use_in_outputs and self.use_in_outputs and self.in_outputs is not None:
                    try:
                        out = self.get_in_output(_target_name)
                        if out is not None:
                            return out
                    except AttributeError:
                        # Deliberately ignored to fall back to the source method below;
                        # presumably raised when there is no such in-output - TODO confirm
                        pass

                if _method_name is None:
                    raise ValueError(f"Field '{_target_name}' must be prefilled")
                return getattr(self, _method_name)(**_method_kwargs)

            new_prop.__name__ = target_name
            new_prop.__module__ = cls.__module__
            new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
            new_prop.__doc__ = docstring
            # Remaining settings are forwarded to the decorator
            setattr(cls, new_prop.__name__, decorator(new_prop, **settings))
        return cls

    return wrapper
class PriceTypeT(tp.NamedTuple):
    # Field layout of the `PriceType` enum.

    # Opening price (the flag is negative infinity, hence a float)
    Open: float = -np.inf
    # Closing price (the flag is positive infinity, hence a float)
    Close: float = np.inf
    # Next opening price
    NextOpen: int = -1
    # Next closing price
    NextClose: int = -2
    # Next valid (non-NA) opening price
    NextValidOpen: int = -3
    # Next valid (non-NA) closing price
    NextValidClose: int = -4
class ValPriceTypeT(tp.NamedTuple):
    # Field layout of the `ValPriceType` enum (asset valuation price type).

    # Latest price (the flag is negative infinity, hence a float)
    Latest: float = -np.inf
    # Order price (the flag is positive infinity, hence a float)
    Price: float = np.inf
class AccumulationModeT(tp.NamedTuple):
    # Field layout of the `AccumulationMode` enum.

    # Disable accumulation (can also be provided as False)
    Disabled: int = 0
    # Allow both adding to and removing from the position (can also be provided as True)
    Both: int = 1
    # Allow accumulation to only add to the position
    AddOnly: int = 2
    # Allow accumulation to only remove from the position
    RemoveOnly: int = 3
class DirectionConflictModeT(tp.NamedTuple):
    # Field layout of the `DirectionConflictMode` enum.

    # Ignore both entry signals
    Ignore: int = 0
    # Execute the long entry signal
    Long: int = 1
    # Execute the short entry signal
    Short: int = 2
    # Execute the adjacent entry signal (only when in position)
    Adjacent: int = 3
    # Execute the opposite entry signal (only when in position)
    Opposite: int = 4


DirectionConflictMode = DirectionConflictModeT()
"""_"""

__pdoc__[
    "DirectionConflictMode"
] = f"""Direction conflict mode.

```python
{prettify(DirectionConflictMode)}
```

What should happen if both a long entry signal and a short entry signal occur simultaneously?

Attributes:
    Ignore: Ignore both entry signals.
    Long: Execute the long entry signal.
    Short: Execute the short entry signal.
    Adjacent: Execute the adjacent entry signal.
        Takes effect only when in position, otherwise ignores.
    Opposite: Execute the opposite entry signal.
        Takes effect only when in position, otherwise ignores.
"""
class TimeDeltaFormatT(tp.NamedTuple):
    # Field layout of the `TimeDeltaFormat` enum.

    # Row format where 1 means one row (simulation step) has passed; needs no index
    Rows: int = 0
    # Index format where 1 means one value in index has passed; requires the index
    Index: int = 1
class StopEntryPriceT(tp.NamedTuple):
    # Field layout of the `StopEntryPrice` enum.
    # Each flag is negative, thus a positive value is used directly as price.

    # Asset valuation price
    ValPrice: int = -1
    # Opening price
    Open: int = -2
    # Order price
    Price: int = -3
    # Filled order price (that is, slippage is already applied)
    FillPrice: int = -4
    # Closing price
    Close: int = -5
class SizeTypeT(tp.NamedTuple):
    # Field layout of the `SizeType` enum.

    # Amount of assets to trade
    Amount: int = 0
    # Asset value to trade
    Value: int = 1
    # Percentage of available resources to use in either direction, where 0.01 means 1%
    Percent: int = 2
    # Same as `Percent` but 1.0 means 1%
    Percent100: int = 3
    # Percentage of total value
    ValuePercent: int = 4
    # Same as `ValuePercent` but 1.0 means 1%
    ValuePercent100: int = 5
    # Target amount of assets to hold (= target position)
    TargetAmount: int = 6
    # Target asset value
    TargetValue: int = 7
    # Target percentage of total value
    TargetPercent: int = 8
    # Same as `TargetPercent` but 1.0 means 1%
    TargetPercent100: int = 9
TargetPercent100: `SizeType.TargetPercent` where 1.0 means 1%. """ class DirectionT(tp.NamedTuple): LongOnly: int = 0 ShortOnly: int = 1 Both: int = 2 Direction = DirectionT() """_""" __pdoc__[ "Direction" ] = f"""Position direction. ```python {prettify(Direction)} ``` Attributes: LongOnly: Only long positions. ShortOnly: Only short positions. Both: Both long and short positions. """ class LeverageModeT(tp.NamedTuple): Lazy: int = 0 Eager: int = 1 LeverageMode = LeverageModeT() """_""" __pdoc__[ "LeverageMode" ] = f"""Leverage mode. ```python {prettify(LeverageMode)} ``` Attributes: Lazy: Applies leverage only if free cash has been exhausted. Eager: Applies leverage to each order. """ class PriceAreaVioModeT(tp.NamedTuple): Ignore: int = 0 Cap: int = 1 Error: int = 2 PriceAreaVioMode = PriceAreaVioModeT() """_""" __pdoc__[ "PriceAreaVioMode" ] = f"""Price are violation mode. ```python {prettify(PriceAreaVioMode)} ``` Attributes: Ignore: Ignore any violation. Cap: Cap price to prevent violation. Error: Throw an error upon violation. """ class OrderStatusT(tp.NamedTuple): Filled: int = 0 Ignored: int = 1 Rejected: int = 2 OrderStatus = OrderStatusT() """_""" __pdoc__[ "OrderStatus" ] = f"""Order status. ```python {prettify(OrderStatus)} ``` Attributes: Filled: Order has been filled. Ignored: Order has been ignored. Rejected: Order has been rejected. """ class OrderStatusInfoT(tp.NamedTuple): SizeNaN: int = 0 PriceNaN: int = 1 ValPriceNaN: int = 2 ValueNaN: int = 3 ValueZeroNeg: int = 4 SizeZero: int = 5 NoCash: int = 6 NoOpenPosition: int = 7 MaxSizeExceeded: int = 8 RandomEvent: int = 9 CantCoverFees: int = 10 MinSizeNotReached: int = 11 PartialFill: int = 12 OrderStatusInfo = OrderStatusInfoT() """_""" __pdoc__[ "OrderStatusInfo" ] = f"""Order status information. 
```python {prettify(OrderStatusInfo)} ``` """ status_info_desc = [ "Size is NaN", "Price is NaN", "Asset valuation price is NaN", "Asset/group value is NaN", "Asset/group value is zero or negative", "Size is zero", "Not enough cash", "No open position to reduce/close", "Size is greater than maximum allowed", "Random event happened", "Not enough cash to cover fees", "Final size is less than minimum allowed", "Final size is less than requested", ] """_""" __pdoc__[ "status_info_desc" ] = f"""Order status description. ```python {prettify(status_info_desc)} ``` """ class OrderSideT(tp.NamedTuple): Buy: int = 0 Sell: int = 1 OrderSide = OrderSideT() """_""" __pdoc__[ "OrderSide" ] = f"""Order side. ```python {prettify(OrderSide)} ``` """ class OrderTypeT(tp.NamedTuple): Market: int = 0 Limit: int = 1 OrderType = OrderTypeT() """_""" __pdoc__[ "OrderType" ] = f"""Order type. ```python {prettify(OrderType)} ``` """ class LimitOrderPriceT(tp.NamedTuple): Limit: int = -1 HardLimit: int = -2 Close: int = -3 LimitOrderPrice = LimitOrderPriceT() """_""" __pdoc__[ "LimitOrderPrice" ] = f"""Limit order price. ```python {prettify(LimitOrderPrice)} ``` Which price to use when executing a limit order? Attributes: Limit: Limit price. If the target price is first hit by the opening price, the opening price is used. HardLimit: Hard limit price. The stop price is used regardless of whether the target price is first hit by the opening price. Close: Closing price. !!! note Each flag is negative, thus if a positive value is provided, it's used directly as price. """ class TradeDirectionT(tp.NamedTuple): Long: int = 0 Short: int = 1 TradeDirection = TradeDirectionT() """_""" __pdoc__[ "TradeDirection" ] = f"""Trade direction. ```python {prettify(TradeDirection)} ``` """ class TradeStatusT(tp.NamedTuple): Open: int = 0 Closed: int = 1 TradeStatus = TradeStatusT() """_""" __pdoc__[ "TradeStatus" ] = f"""Event status. 
```python {prettify(TradeStatus)} ``` """ class TradesTypeT(tp.NamedTuple): Trades: int = 0 EntryTrades: int = 1 ExitTrades: int = 2 Positions: int = 3 TradesType = TradesTypeT() """_""" __pdoc__[ "TradesType" ] = f"""Trades type. ```python {prettify(TradesType)} ``` """ class OrderPriceStatusT(tp.NamedTuple): OK: int = 0 AboveHigh: int = 1 BelowLow: int = 2 Unknown: int = 3 OrderPriceStatus = OrderPriceStatusT() """_""" __pdoc__[ "OrderPriceStatus" ] = f"""Order price error. ```python {prettify(OrderPriceStatus)} ``` Attributes: OK: Order price is within OHLC. AboveHigh: Order price is above high. BelowLow: Order price is below low. Unknown: High and/or low are unknown. """ class PositionFeatureT(tp.NamedTuple): EntryPrice: int = 0 ExitPrice: int = 1 PositionFeature = PositionFeatureT() """_""" __pdoc__[ "PositionFeature" ] = f"""Position feature. ```python {prettify(PositionFeature)} ``` """ # ############# Named tuples ############# # class PriceArea(tp.NamedTuple): open: float high: float low: float close: float __pdoc__[ "PriceArea" ] = """Price area defined by four boundaries. Used together with `PriceAreaVioMode`.""" __pdoc__["PriceArea.open"] = "Opening price of the time step." __pdoc__[ "PriceArea.high" ] = """Highest price of the time step. Violation takes place when adjusted price goes above this value. """ __pdoc__[ "PriceArea.low" ] = """Lowest price of the time step. Violation takes place when adjusted price goes below this value. """ __pdoc__[ "PriceArea.close" ] = """Closing price of the time step. Violation takes place when adjusted price goes beyond this value. """ NoPriceArea = PriceArea(open=np.nan, high=np.nan, low=np.nan, close=np.nan) """_""" __pdoc__["NoPriceArea"] = "No price area." class AccountState(tp.NamedTuple): cash: float position: float debt: float locked_cash: float free_cash: float __pdoc__["AccountState"] = "State of the account." __pdoc__[ "AccountState.cash" ] = """Cash. 
Per group with cash sharing, otherwise per column.""" __pdoc__[ "AccountState.position" ] = """Position. Per column.""" __pdoc__[ "AccountState.debt" ] = """Debt. Per column.""" __pdoc__[ "AccountState.locked_cash" ] = """Locked cash. Per column.""" __pdoc__[ "AccountState.free_cash" ] = """Free cash. Per group with cash sharing, otherwise per column.""" class ExecState(tp.NamedTuple): cash: float position: float debt: float locked_cash: float free_cash: float val_price: float value: float __pdoc__["ExecState"] = "State before or after order execution." __pdoc__["ExecState.cash"] = "See `AccountState.cash`." __pdoc__["ExecState.position"] = "See `AccountState.position`." __pdoc__["ExecState.debt"] = "See `AccountState.debt`." __pdoc__["ExecState.locked_cash"] = "See `AccountState.locked_cash`." __pdoc__["ExecState.free_cash"] = "See `AccountState.free_cash`." __pdoc__["ExecState.val_price"] = "Valuation price in the current column." __pdoc__["ExecState.value"] = "Value in the current column (or group with cash sharing)." class SimulationOutput(tp.NamedTuple): order_records: tp.RecordArray2d log_records: tp.RecordArray2d cash_deposits: tp.Array2d cash_earnings: tp.Array2d call_seq: tp.Optional[tp.Array2d] in_outputs: tp.Optional[tp.NamedTuple] sim_start: tp.Optional[tp.Array1d] sim_end: tp.Optional[tp.Array1d] __pdoc__["SimulationOutput"] = "A named tuple representing the output of a simulation." __pdoc__["SimulationOutput.order_records"] = "Order records (flattened)." __pdoc__["SimulationOutput.log_records"] = "Log records (flattened)." __pdoc__[ "SimulationOutput.cash_deposits" ] = """Cash deposited/withdrawn at each timestamp. If not tracked, becomes zero of shape `(1, 1)`.""" __pdoc__[ "SimulationOutput.cash_earnings" ] = """Cash earnings added/removed at each timestamp. If not tracked, becomes zero of shape `(1, 1)`.""" __pdoc__[ "SimulationOutput.call_seq" ] = """Call sequence. 
If not tracked, becomes None.""" __pdoc__[ "SimulationOutput.in_outputs" ] = """Named tuple with in-output objects. If not tracked, becomes None.""" __pdoc__[ "SimulationOutput.sim_start" ] = """Start of the simulation per column. Use `vectorbtpro.generic.nb.sim_range.prepare_ungrouped_sim_range_nb` to ungroup the array. If not tracked, becomes None.""" __pdoc__[ "SimulationOutput.sim_end" ] = """End of the simulation per column. Use `vectorbtpro.generic.nb.sim_range.prepare_ungrouped_sim_range_nb` to ungroup the array. If not tracked, becomes None.""" class SimulationContext(tp.NamedTuple): target_shape: tp.Shape group_lens: tp.GroupLens cash_sharing: bool call_seq: tp.Optional[tp.Array2d] init_cash: tp.FlexArray1d init_position: tp.FlexArray1d init_price: tp.FlexArray1d cash_deposits: tp.FlexArray2d cash_earnings: tp.FlexArray2d segment_mask: tp.FlexArray2d call_pre_segment: bool call_post_segment: bool index: tp.Optional[tp.Array1d] freq: tp.Optional[int] open: tp.FlexArray2d high: tp.FlexArray2d low: tp.FlexArray2d close: tp.FlexArray2d bm_close: tp.FlexArray2d ffill_val_price: bool update_value: bool fill_pos_info: bool track_value: bool order_records: tp.RecordArray2d order_counts: tp.Array1d log_records: tp.RecordArray2d log_counts: tp.Array1d in_outputs: tp.Optional[tp.NamedTuple] last_cash: tp.Array1d last_position: tp.Array1d last_debt: tp.Array1d last_locked_cash: tp.Array1d last_free_cash: tp.Array1d last_val_price: tp.Array1d last_value: tp.Array1d last_return: tp.Array1d last_pos_info: tp.RecordArray sim_start: tp.Array1d sim_end: tp.Array1d __pdoc__[ "SimulationContext" ] = """A named tuple representing the context of a simulation. Contains general information available to all other contexts. Passed to `pre_sim_func_nb` and `post_sim_func_nb`.""" __pdoc__[ "SimulationContext.target_shape" ] = """Target shape of the simulation. A tuple with exactly two elements: the number of rows and columns. 
Example: One day of minute data for three assets would yield a `target_shape` of `(1440, 3)`, where the first axis are rows (minutes) and the second axis are columns (assets). """ __pdoc__[ "SimulationContext.group_lens" ] = """Number of columns in each group. Even if columns are not grouped, `group_lens` contains ones - one column per group. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. Example: In pairs trading, `group_lens` would be `np.array([2])`, while three independent columns would be represented by `group_lens` of `np.array([1, 1, 1])`. """ __pdoc__["SimulationContext.cash_sharing"] = "Whether cash sharing is enabled." __pdoc__[ "SimulationContext.call_seq" ] = """Default sequence of calls per segment. Controls the sequence in which `order_func_nb` is executed within each segment. Has shape `SimulationContext.target_shape` and each value must exist in the range `[0, group_len)`. Can also be None if not provided. !!! note To use `sort_call_seq_1d_nb`, must be generated using `CallSeqType.Default`. To change the call sequence dynamically, better change `SegmentContext.call_seq_now` in-place. Example: The default call sequence for three data points and two groups with three columns each: ```python np.array([ [0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2] ]) ``` """ __pdoc__[ "SimulationContext.init_cash" ] = """Initial capital per column (or per group with cash sharing). Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`. Must broadcast to shape `(group_lens.shape[0],)` with cash sharing, otherwise `(target_shape[1],)`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. Example: Consider three columns, each having $100 of starting capital. 
If we built one group of two columns and one group of one column, the `init_cash` would be `np.array([200, 100])` with cash sharing and `np.array([100, 100, 100])` without cash sharing. """ __pdoc__[ "SimulationContext.init_position" ] = """Initial position per column. Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`. Must broadcast to shape `(target_shape[1],)`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.init_price" ] = """Initial position price per column. Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`. Must broadcast to shape `(target_shape[1],)`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.cash_deposits" ] = """Cash to be deposited/withdrawn per column (or per group with cash sharing). Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`. Must broadcast to shape `(target_shape[0], group_lens.shape[0])`. Cash is deposited/withdrawn right after `pre_segment_func_nb`. You can modify this array in `pre_segment_func_nb`. !!! note To modify the array in place, make sure to build an array of the full shape. """ __pdoc__[ "SimulationContext.cash_earnings" ] = """Earnings to be added per column. Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`. Must broadcast to shape `SimulationContext.target_shape`. Earnings are added right before `post_segment_func_nb` and are already included in the value of each group. You can modify this array in `pre_segment_func_nb` or `post_order_func_nb`. !!! note To modify the array in place, make sure to build an array of the full shape. """ __pdoc__[ "SimulationContext.segment_mask" ] = """Mask of whether a particular segment should be executed. 
A segment is simply a sequence of `order_func_nb` calls under the same group and row. If a segment is inactive, any callback function inside of it will not be executed. You can still execute the segment's pre- and postprocessing function by enabling `SimulationContext.call_pre_segment` and `SimulationContext.call_post_segment` respectively. Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`. Must broadcast to shape `(target_shape[0], group_lens.shape[0])`. !!! note To modify the array in place, make sure to build an array of the full shape. Example: Consider two groups with two columns each and the following activity mask: ```python np.array([[ True, False], [False, True]]) ``` The first group is only executed in the first row and the second group is only executed in the second row. """ __pdoc__[ "SimulationContext.call_pre_segment" ] = """Whether to call `pre_segment_func_nb` regardless of `SimulationContext.segment_mask`.""" __pdoc__[ "SimulationContext.call_post_segment" ] = """Whether to call `post_segment_func_nb` regardless of `SimulationContext.segment_mask`. Allows, for example, to write user-defined arrays such as returns at the end of each segment.""" __pdoc__[ "SimulationContext.index" ] = """Index in integer (nanosecond) format. If datetime-like, assumed to have the UTC timezone. Preset simulation methods will automatically format any index as UTC without actually converting it to UTC, that is, `12:00 +02:00` will become `12:00 +00:00` to avoid timezone conversion issues.""" __pdoc__["SimulationContext.freq"] = """Frequency of index in integer (nanosecond) format.""" __pdoc__[ "SimulationContext.open" ] = """Opening price. Replaces `Order.price` in case it's `-np.inf`. Similar behavior to that of `SimulationContext.close`.""" __pdoc__[ "SimulationContext.high" ] = """Highest price. Similar behavior to that of `SimulationContext.close`.""" __pdoc__[ "SimulationContext.low" ] = """Lowest price. 
Similar behavior to that of `SimulationContext.close`.""" __pdoc__[ "SimulationContext.close" ] = """Closing price at each time step. Replaces `Order.price` in case it's `np.inf`. Acts as a boundary - see `PriceArea.close`. Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`. Must broadcast to shape `SimulationContext.target_shape`. !!! note To modify the array in place, make sure to build an array of the full shape. """ __pdoc__[ "SimulationContext.bm_close" ] = """Benchmark closing price at each time step. Must broadcast to shape `SimulationContext.target_shape`.""" __pdoc__[ "SimulationContext.ffill_val_price" ] = """Whether to track valuation price only if it's known. Otherwise, unknown `SimulationContext.close` will lead to NaN in valuation price at the next timestamp.""" __pdoc__[ "SimulationContext.update_value" ] = """Whether to update group value after each filled order. Otherwise, stays the same for all columns in the group (the value is calculated only once, before executing any order). The change is marginal and mostly driven by transaction costs and slippage.""" __pdoc__[ "SimulationContext.fill_pos_info" ] = """Whether to fill position record. Disable this to make simulation faster for simple use cases.""" __pdoc__[ "SimulationContext.track_value" ] = """Whether to track value metrics such as the current valuation price, value, and return. If False, 'SimulationContext.last_val_price', 'SimulationContext.last_value', and 'SimulationContext.last_return' will stay NaN and the statistics of any open position won't be updated. You won't be able to use `SizeType.Value`, `SizeType.TargetValue`, and `SizeType.TargetPercent`. Disable this to make simulation faster for simple use cases.""" __pdoc__[ "SimulationContext.order_records" ] = """Order records per column. It's a 2-dimensional array with records of type `order_dt`. 
The array is initialized with empty records first (they contain random data), and then gradually filled with order data. The number of empty records depends upon `max_order_records`, but usually it matches the number of rows, meaning there is maximal one order record per element. `max_order_records` can be chosen lower if not every `order_func_nb` leads to a filled order, to save memory. It can also be chosen higher if more than one order per element is expected. You can use `SimulationContext.order_counts` to get the number of filled orders in each column. To get all order records filled up to this point in a column, do `order_records[:order_counts[col], col]`. Example: Before filling, each order record looks like this: ```python np.array([(-8070450532247928832, -8070450532247928832, 4, 0., 0., 0., 5764616306889786413)] ``` After filling, it becomes like this: ```python np.array([(0, 0, 1, 50., 1., 0., 1)] ``` """ __pdoc__[ "SimulationContext.order_counts" ] = """Number of filled order records in each column. Points to `SimulationContext.order_records` and has shape `(target_shape[1],)`. Example: `order_counts` of `np.array([2, 100, 0])` means the latest filled order is `order_records[1, 0]` in the first column, `order_records[99, 1]` in the second column, and no orders have been filled yet in the third column (`order_records[0, 2]` is empty). !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.log_records" ] = """Log records per column. Similar to `SimulationContext.order_records` but of type `log_dt` and count `SimulationContext.log_counts`.""" __pdoc__[ "SimulationContext.log_counts" ] = """Number of filled log records in each column. Similar to `SimulationContext.log_counts` but for log records. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. 
""" __pdoc__[ "SimulationContext.in_outputs" ] = """Named tuple with in-output objects. Can contain objects of arbitrary shape and type. Will be returned as part of `SimulationOutput`.""" __pdoc__[ "SimulationContext.last_cash" ] = """Latest cash per column (or per group with cash sharing). At the very first timestamp, contains initial capital. Gets updated right after `order_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_position" ] = """Latest position per column. At the very first timestamp, contains initial position. Has shape `(target_shape[1],)`. Gets updated right after `order_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_debt" ] = """Latest debt from leverage or shorting per column. Has shape `(target_shape[1],)`. Gets updated right after `order_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_locked_cash" ] = """Latest locked cash from leverage or shorting per column. Has shape `(target_shape[1],)`. Gets updated right after `order_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_free_cash" ] = """Latest free cash per column (or per group with cash sharing). Free cash never goes above the initial level, because an operation always costs money. Has shape `(target_shape[1],)`. Gets updated right after `order_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_val_price" ] = """Latest valuation price per column. Has shape `(target_shape[1],)`. 
Enables `SizeType.Value`, `SizeType.TargetValue`, and `SizeType.TargetPercent`. Gets multiplied by the current position to get the value of the column (see `SimulationContext.last_value`). Gets updated right before `pre_segment_func_nb` using `SimulationContext.open`. Then, gets updated right after `pre_segment_func_nb` - you can use `pre_segment_func_nb` to override `last_val_price` in-place, such that `order_func_nb` can use the new group value. If `SimulationContext.update_value`, gets also updated right after `order_func_nb` using filled order price as the latest known price. Finally, gets updated right before `post_segment_func_nb` using `SimulationContext.close`. If `SimulationContext.ffill_val_price`, gets updated only if the value is not NaN. For example, close of `[1, 2, np.nan, np.nan, 5]` yields valuation price of `[1, 2, 2, 2, 5]`. !!! note You are not allowed to use `-np.inf` or `np.inf` - only finite values. If `SimulationContext.open` is NaN in the first row, the `last_val_price` is also NaN. Example: Consider 10 units in column 1 and 20 units in column 2. The current opening price of them is $40 and $50 respectively, which is also the default valuation price in the current row, available as `last_val_price` in `pre_segment_func_nb`. If both columns are in the same group with cash sharing, the group is valued at $1400 before any `order_func_nb` is called, and can be later accessed via `OrderContext.value_now`. """ __pdoc__[ "SimulationContext.last_value" ] = """Latest value per column (or per group with cash sharing). Calculated by multiplying the valuation price by the current position and adding the cash. The value in each column in a group with cash sharing is summed to get the value of the entire group. Gets updated right before `pre_segment_func_nb`. Then, gets updated right after `pre_segment_func_nb`. 
If `SimulationContext.update_value`, gets also updated right after `order_func_nb` using filled order price as the latest known price (the difference will be minimal, only affected by costs). Finally, gets updated right before `post_segment_func_nb`. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_return" ] = """Latest return per column (or per group with cash sharing). Has the same shape as `SimulationContext.last_value`. Calculated by comparing the current `SimulationContext.last_value` to the last one of the previous row. Gets updated each time `SimulationContext.last_value` is updated. !!! note Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`. """ __pdoc__[ "SimulationContext.last_pos_info" ] = """Latest position record in each column. It's a 1-dimensional array with records of type `trade_dt`. Has shape `(target_shape[1],)`. If `SimulationContext.init_position` is not zero in a column, that column's position record is automatically filled before the simulation with `entry_price` set to `SimulationContext.init_price` and `entry_idx` of -1. The fields `entry_price` and `exit_price` are average entry and exit price respectively. The average exit price does **not** contain open statistics, as opposed to `vectorbtpro.portfolio.trades.Positions`. On the other hand, fields `pnl` and `return` contain statistics as if the position has been closed and are re-calculated using `SimulationContext.last_val_price` right before and after `pre_segment_func_nb`, right after `order_func_nb`, and right before `post_segment_func_nb`. !!! note In an open position record, the field `exit_price` doesn't reflect the latest valuation price, but keeps the average price at which the position has been reduced. 
""" __pdoc__[ "SimulationContext.sim_start" ] = """Start of the simulation per column or group (also without cash sharing). Changing in-place won't apply to the current simulation.""" __pdoc__[ "SimulationContext.sim_start" ] = """End of the simulation per column or group (also without cash sharing). Changing in-place will apply to the current simulation if it's lower than the initial value.""" class GroupContext(tp.NamedTuple): target_shape: tp.Shape group_lens: tp.GroupLens cash_sharing: bool call_seq: tp.Optional[tp.Array2d] init_cash: tp.FlexArray1d init_position: tp.FlexArray1d init_price: tp.FlexArray1d cash_deposits: tp.FlexArray2d cash_earnings: tp.FlexArray2d segment_mask: tp.FlexArray2d call_pre_segment: bool call_post_segment: bool index: tp.Optional[tp.Array1d] freq: tp.Optional[int] open: tp.FlexArray2d high: tp.FlexArray2d low: tp.FlexArray2d close: tp.FlexArray2d bm_close: tp.FlexArray2d ffill_val_price: bool update_value: bool fill_pos_info: bool track_value: bool order_records: tp.RecordArray2d order_counts: tp.Array1d log_records: tp.RecordArray2d log_counts: tp.Array1d in_outputs: tp.Optional[tp.NamedTuple] last_cash: tp.Array1d last_position: tp.Array1d last_debt: tp.Array1d last_locked_cash: tp.Array1d last_free_cash: tp.Array1d last_val_price: tp.Array1d last_value: tp.Array1d last_return: tp.Array1d last_pos_info: tp.RecordArray sim_start: tp.Array1d sim_end: tp.Array1d group: int group_len: int from_col: int to_col: int __pdoc__[ "GroupContext" ] = """A named tuple representing the context of a group. A group is a set of nearby columns that are somehow related (for example, by sharing the same capital). In each row, the columns under the same group are bound to the same segment. Contains all fields from `SimulationContext` plus fields describing the current group. Passed to `pre_group_func_nb` and `post_group_func_nb`. 
Example: Consider a group of three columns, a group of two columns, and one more column: | group | group_len | from_col | to_col | | ----- | --------- | -------- | ------ | | 0 | 3 | 0 | 3 | | 1 | 2 | 3 | 5 | | 2 | 1 | 5 | 6 | """ for field in GroupContext._fields: if field in SimulationContext._fields: __pdoc__["GroupContext." + field] = f"See `SimulationContext.{field}`." __pdoc__[ "GroupContext.group" ] = """Index of the current group. Has range `[0, group_lens.shape[0])`. """ __pdoc__[ "GroupContext.group_len" ] = """Number of columns in the current group. Scalar value. Same as `group_lens[group]`. """ __pdoc__[ "GroupContext.from_col" ] = """Index of the first column in the current group. Has range `[0, target_shape[1])`. """ __pdoc__[ "GroupContext.to_col" ] = """Index of the last column in the current group plus one. Has range `[1, target_shape[1] + 1)`. If columns are not grouped, equals to `from_col + 1`. !!! warning In the last group, `to_col` points at a column that doesn't exist. 
""" class RowContext(tp.NamedTuple): target_shape: tp.Shape group_lens: tp.GroupLens cash_sharing: bool call_seq: tp.Optional[tp.Array2d] init_cash: tp.FlexArray1d init_position: tp.FlexArray1d init_price: tp.FlexArray1d cash_deposits: tp.FlexArray2d cash_earnings: tp.FlexArray2d segment_mask: tp.FlexArray2d call_pre_segment: bool call_post_segment: bool index: tp.Optional[tp.Array1d] freq: tp.Optional[int] open: tp.FlexArray2d high: tp.FlexArray2d low: tp.FlexArray2d close: tp.FlexArray2d bm_close: tp.FlexArray2d ffill_val_price: bool update_value: bool fill_pos_info: bool track_value: bool order_records: tp.RecordArray2d order_counts: tp.Array1d log_records: tp.RecordArray2d log_counts: tp.Array1d in_outputs: tp.Optional[tp.NamedTuple] last_cash: tp.Array1d last_position: tp.Array1d last_debt: tp.Array1d last_locked_cash: tp.Array1d last_free_cash: tp.Array1d last_val_price: tp.Array1d last_value: tp.Array1d last_return: tp.Array1d last_pos_info: tp.RecordArray sim_start: tp.Array1d sim_end: tp.Array1d i: int __pdoc__[ "RowContext" ] = """A named tuple representing the context of a row. A row is a time step in which segments are executed. Contains all fields from `SimulationContext` plus fields describing the current row. Passed to `pre_row_func_nb` and `post_row_func_nb`. """ for field in RowContext._fields: if field in SimulationContext._fields: __pdoc__["RowContext." + field] = f"See `SimulationContext.{field}`." __pdoc__[ "RowContext.i" ] = """Index of the current row. Has range `[0, target_shape[0])`. 
""" class SegmentContext(tp.NamedTuple): target_shape: tp.Shape group_lens: tp.GroupLens cash_sharing: bool call_seq: tp.Optional[tp.Array2d] init_cash: tp.FlexArray1d init_position: tp.FlexArray1d init_price: tp.FlexArray1d cash_deposits: tp.FlexArray2d cash_earnings: tp.FlexArray2d segment_mask: tp.FlexArray2d call_pre_segment: bool call_post_segment: bool index: tp.Optional[tp.Array1d] freq: tp.Optional[int] open: tp.FlexArray2d high: tp.FlexArray2d low: tp.FlexArray2d close: tp.FlexArray2d bm_close: tp.FlexArray2d ffill_val_price: bool update_value: bool fill_pos_info: bool track_value: bool order_records: tp.RecordArray2d order_counts: tp.Array1d log_records: tp.RecordArray2d log_counts: tp.Array1d in_outputs: tp.Optional[tp.NamedTuple] last_cash: tp.Array1d last_position: tp.Array1d last_debt: tp.Array1d last_locked_cash: tp.Array1d last_free_cash: tp.Array1d last_val_price: tp.Array1d last_value: tp.Array1d last_return: tp.Array1d last_pos_info: tp.RecordArray sim_start: tp.Array1d sim_end: tp.Array1d group: int group_len: int from_col: int to_col: int i: int call_seq_now: tp.Optional[tp.Array1d] __pdoc__[ "SegmentContext" ] = """A named tuple representing the context of a segment. A segment is an intersection between groups and rows. It's an entity that defines how and in which order elements within the same group and row are processed. Contains all fields from `SimulationContext`, `GroupContext`, and `RowContext`, plus fields describing the current segment. Passed to `pre_segment_func_nb` and `post_segment_func_nb`. """ for field in SegmentContext._fields: if field in SimulationContext._fields: __pdoc__["SegmentContext." + field] = f"See `SimulationContext.{field}`." elif field in GroupContext._fields: __pdoc__["SegmentContext." + field] = f"See `GroupContext.{field}`." elif field in RowContext._fields: __pdoc__["SegmentContext." + field] = f"See `RowContext.{field}`." 
__pdoc__["SegmentContext.call_seq_now"] = """Sequence of calls within the current segment.

Has shape `(group_len,)`.

Each value in this sequence must indicate the position of column in the group to call next.
Processing goes always from left to right.

You can use `pre_segment_func_nb` to override `call_seq_now`.

Example:
    `[2, 0, 1]` would first call column 2, then 0, and finally 1.
"""


class OrderContext(tp.NamedTuple):
    # --- Fields shared with `SimulationContext` ---
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    call_seq: tp.Optional[tp.Array2d]
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    cash_deposits: tp.FlexArray2d
    cash_earnings: tp.FlexArray2d
    segment_mask: tp.FlexArray2d
    call_pre_segment: bool
    call_post_segment: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    bm_close: tp.FlexArray2d
    ffill_val_price: bool
    update_value: bool
    fill_pos_info: bool
    track_value: bool
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.RecordArray
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    # --- Fields shared with `GroupContext`, `RowContext`, and `SegmentContext` ---
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int
    call_seq_now: tp.Optional[tp.Array1d]
    # --- Fields specific to the current order ---
    col: int
    call_idx: int
    cash_now: float
    position_now: float
    debt_now: float
    locked_cash_now: float
    free_cash_now: float
    val_price_now: float
    value_now: float
    return_now: float
    pos_info_now: tp.Record


__pdoc__["OrderContext"] = """A named tuple representing the context of an order.

Contains all fields from `SegmentContext` plus fields describing the current state.

Passed to `order_func_nb`.
"""

# Link every inherited field to the most general context that declares it.
# The contexts are probed from the broadest to the narrowest; the first hit wins.
# Fields declared only by `OrderContext` itself are documented explicitly below.
for field in OrderContext._fields:
    if field in SimulationContext._fields:
        __pdoc__["OrderContext." + field] = f"See `SimulationContext.{field}`."
        continue
    if field in GroupContext._fields:
        __pdoc__["OrderContext." + field] = f"See `GroupContext.{field}`."
        continue
    if field in RowContext._fields:
        __pdoc__["OrderContext." + field] = f"See `RowContext.{field}`."
        continue
    if field in SegmentContext._fields:
        __pdoc__["OrderContext." + field] = f"See `SegmentContext.{field}`."

__pdoc__["OrderContext.col"] = """Current column.

Has range `[0, target_shape[1])` and is always within `[from_col, to_col)`.
"""
__pdoc__["OrderContext.call_idx"] = """Index of the current call in `SegmentContext.call_seq_now`.

Has range `[0, group_len)`.
"""
__pdoc__["OrderContext.cash_now"] = "`SimulationContext.last_cash` for the current column/group."
__pdoc__["OrderContext.position_now"] = "`SimulationContext.last_position` for the current column."
__pdoc__["OrderContext.debt_now"] = "`SimulationContext.last_debt` for the current column."
__pdoc__["OrderContext.locked_cash_now"] = "`SimulationContext.last_locked_cash` for the current column."
__pdoc__["OrderContext.free_cash_now"] = "`SimulationContext.last_free_cash` for the current column/group."
__pdoc__["OrderContext.val_price_now"] = "`SimulationContext.last_val_price` for the current column."
__pdoc__["OrderContext.value_now"] = "`SimulationContext.last_value` for the current column/group."
__pdoc__["OrderContext.return_now"] = "`SimulationContext.last_return` for the current column/group."
__pdoc__["OrderContext.pos_info_now"] = "`SimulationContext.last_pos_info` for the current column."
class PostOrderContext(tp.NamedTuple):
    # Context passed to `post_order_func_nb`.
    # The declaration order of the fields is the positional layout of the tuple; do not reorder.
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    call_seq: tp.Optional[tp.Array2d]
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    cash_deposits: tp.FlexArray2d
    cash_earnings: tp.FlexArray2d
    segment_mask: tp.FlexArray2d
    call_pre_segment: bool
    call_post_segment: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    bm_close: tp.FlexArray2d
    ffill_val_price: bool
    update_value: bool
    fill_pos_info: bool
    track_value: bool
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.RecordArray
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int
    call_seq_now: tp.Optional[tp.Array1d]
    col: int
    call_idx: int
    # State before the order was executed
    cash_before: float
    position_before: float
    debt_before: float
    locked_cash_before: float
    free_cash_before: float
    val_price_before: float
    value_before: float
    # Outcome of the order
    order_result: "OrderResult"
    # State after the order was executed
    cash_now: float
    position_now: float
    debt_now: float
    locked_cash_now: float
    free_cash_now: float
    val_price_now: float
    value_now: float
    return_now: float
    pos_info_now: tp.Record


__pdoc__[
    "PostOrderContext"
] = """A named tuple representing the context after an order has been processed.

Contains all fields from `OrderContext` plus fields describing the order result and the previous state.

Passed to `post_order_func_nb`.
"""
# Point every shared field at the first context (outermost first) that declares it.
# The entries for `*_now` fields written here are overridden by the explicit ones below.
for field in PostOrderContext._fields:
    if field in SimulationContext._fields:
        __pdoc__["PostOrderContext." + field] = f"See `SimulationContext.{field}`."
    elif field in GroupContext._fields:
        __pdoc__["PostOrderContext." + field] = f"See `GroupContext.{field}`."
    elif field in RowContext._fields:
        __pdoc__["PostOrderContext." + field] = f"See `RowContext.{field}`."
    elif field in SegmentContext._fields:
        __pdoc__["PostOrderContext." + field] = f"See `SegmentContext.{field}`."
    elif field in OrderContext._fields:
        __pdoc__["PostOrderContext." + field] = f"See `OrderContext.{field}`."
__pdoc__["PostOrderContext.cash_before"] = "`OrderContext.cash_now` before execution."
__pdoc__["PostOrderContext.position_before"] = "`OrderContext.position_now` before execution."
__pdoc__["PostOrderContext.debt_before"] = "`OrderContext.debt_now` before execution."
__pdoc__["PostOrderContext.locked_cash_before"] = "`OrderContext.locked_cash_now` before execution."
__pdoc__["PostOrderContext.free_cash_before"] = "`OrderContext.free_cash_now` before execution."
__pdoc__["PostOrderContext.val_price_before"] = "`OrderContext.val_price_now` before execution."
__pdoc__["PostOrderContext.value_before"] = "`OrderContext.value_now` before execution."
__pdoc__[
    "PostOrderContext.order_result"
] = """Order result of type `OrderResult`.

Can be used to check whether the order has been filled, ignored, or rejected.
"""
__pdoc__["PostOrderContext.cash_now"] = "`OrderContext.cash_now` after execution."
__pdoc__["PostOrderContext.position_now"] = "`OrderContext.position_now` after execution."
__pdoc__["PostOrderContext.debt_now"] = "`OrderContext.debt_now` after execution."
__pdoc__["PostOrderContext.locked_cash_now"] = "`OrderContext.locked_cash_now` after execution."
__pdoc__["PostOrderContext.free_cash_now"] = "`OrderContext.free_cash_now` after execution."
__pdoc__[
    "PostOrderContext.val_price_now"
] = """`OrderContext.val_price_now` after execution.

If `SimulationContext.update_value`, gets replaced with the fill price,
as it becomes the most recently known price. Otherwise, stays the same.
"""
__pdoc__[
    "PostOrderContext.value_now"
] = """`OrderContext.value_now` after execution.

If `SimulationContext.update_value`, gets updated with the new cash and value of the column.
Otherwise, stays the same.
"""
__pdoc__["PostOrderContext.return_now"] = "`OrderContext.return_now` after execution."
__pdoc__["PostOrderContext.pos_info_now"] = "`OrderContext.pos_info_now` after execution."


class FlexOrderContext(tp.NamedTuple):
    # Context passed to `flex_order_func_nb`.
    # The declaration order of the fields is the positional layout of the tuple; do not reorder.
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    call_seq: tp.Optional[tp.Array2d]
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    cash_deposits: tp.FlexArray2d
    cash_earnings: tp.FlexArray2d
    segment_mask: tp.FlexArray2d
    call_pre_segment: bool
    call_post_segment: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    bm_close: tp.FlexArray2d
    ffill_val_price: bool
    update_value: bool
    fill_pos_info: bool
    track_value: bool
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.RecordArray
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int
    # Annotated as `None` here, unlike `tp.Optional[tp.Array1d]` in the non-flexible contexts
    call_seq_now: None
    call_idx: int


__pdoc__[
    "FlexOrderContext"
] = """A named tuple representing the context of a flexible order.

Contains all fields from `SegmentContext` plus the current call index.

Passed to `flex_order_func_nb`.
"""
# Point every shared field at the first context (outermost first) that declares it.
for field in FlexOrderContext._fields:
    if field in SimulationContext._fields:
        __pdoc__["FlexOrderContext." + field] = f"See `SimulationContext.{field}`."
    elif field in GroupContext._fields:
        __pdoc__["FlexOrderContext." + field] = f"See `GroupContext.{field}`."
    elif field in RowContext._fields:
        __pdoc__["FlexOrderContext." + field] = f"See `RowContext.{field}`."
    elif field in SegmentContext._fields:
        __pdoc__["FlexOrderContext." + field] = f"See `SegmentContext.{field}`."
__pdoc__["FlexOrderContext.call_idx"] = "Index of the current call."


class Order(tp.NamedTuple):
    # Order request; every field has a default (see the note in the `__pdoc__` entry about Numba).
    size: float = np.inf
    price: float = np.inf
    size_type: int = SizeType.Amount
    direction: int = Direction.Both
    fees: float = 0.0
    fixed_fees: float = 0.0
    slippage: float = 0.0
    min_size: float = np.nan
    max_size: float = np.nan
    size_granularity: float = np.nan
    leverage: float = 1.0
    leverage_mode: int = LeverageMode.Lazy
    reject_prob: float = 0.0
    price_area_vio_mode: int = PriceAreaVioMode.Ignore
    allow_partial: bool = True
    raise_reject: bool = False
    log: bool = False


__pdoc__[
    "Order"
] = """A named tuple representing an order.

!!! note
    Currently, Numba has issues with using defaults when filling named tuples.
    Use `vectorbtpro.portfolio.nb.core.order_nb` to create an order."""
__pdoc__[
    "Order.size"
] = """Size in units.

Behavior depends upon `Order.size_type` and `Order.direction`.

For any fixed size:

* Set to any number to buy/sell some fixed amount or value.
* Set to `np.inf` to buy for all cash, or `-np.inf` to sell for all free cash.
    If `Order.direction` is not `Direction.Both`, `-np.inf` will close the position.
* Set to `np.nan` or 0 to skip.

For any target size:

* Set to any number to buy/sell an amount relative to the current position or value.
* Set to 0 to close the current position.
* Set to `np.nan` to skip.
"""
__pdoc__[
    "Order.price"
] = """Price per unit.

Final price will depend upon slippage.

* If `-np.inf`, gets replaced by the current open.
* If `np.inf`, gets replaced by the current close.

!!! note
    Make sure to use timestamps that come between (and ideally not including) the current open and close."""
__pdoc__["Order.size_type"] = "See `SizeType`."
__pdoc__["Order.direction"] = "See `Direction`."
__pdoc__[
    "Order.fees"
] = """Fees in percentage of the order value.

Negative trading fees like -0.05 mean earning 5% per trade instead of paying a fee.

!!! note
    0.01 = 1%."""
__pdoc__[
    "Order.fixed_fees"
] = """Fixed amount of fees to pay for this order.

Similar to `Order.fees`, can be negative."""
__pdoc__[
    "Order.slippage"
] = """Slippage in percentage of `Order.price`.

Slippage is a penalty applied on the price.

!!! note
    0.01 = 1%."""
__pdoc__[
    "Order.min_size"
] = """Minimum size in both directions.

Depends on `Order.size_type`. Lower than that will be rejected."""
__pdoc__[
    "Order.max_size"
] = """Maximum size in both directions.

Depends on `Order.size_type`. Higher than that will be partly filled."""
__pdoc__[
    "Order.size_granularity"
] = """Granularity of the size.

For example, granularity of 1.0 makes the quantity to behave like an integer.
Placing an order of 12.5 shares (in any direction) will order exactly 12.0 shares.

!!! note
    The filled size remains a floating number."""
__pdoc__["Order.leverage"] = "Leverage."
__pdoc__["Order.leverage_mode"] = "See `LeverageMode`."
__pdoc__[
    "Order.reject_prob"
] = """Probability of rejecting this order to simulate a random rejection event.

Not everything goes smoothly in real life. Use random rejections to test your order management for robustness."""
__pdoc__["Order.price_area_vio_mode"] = "See `PriceAreaVioMode`."
__pdoc__[
    "Order.allow_partial"
] = """Whether to allow partial fill.

Otherwise, the order gets rejected.

Does not apply when `Order.size` is `np.inf`."""
__pdoc__[
    "Order.raise_reject"
] = """Whether to raise exception if order has been rejected.

Terminates the simulation."""
__pdoc__[
    "Order.log"
] = """Whether to log this order by filling a log record.

Remember to increase `max_log_records`."""

# Sentinel order: NaN for most float fields and -1 for the enumerated fields
NoOrder = Order(
    size=np.nan,
    price=np.nan,
    size_type=-1,
    direction=-1,
    fees=np.nan,
    fixed_fees=np.nan,
    slippage=np.nan,
    min_size=np.nan,
    max_size=np.nan,
    size_granularity=np.nan,
    leverage=1.0,
    leverage_mode=LeverageMode.Lazy,
    reject_prob=np.nan,
    price_area_vio_mode=-1,
    allow_partial=False,
    raise_reject=False,
    log=False,
)
"""_"""

__pdoc__["NoOrder"] = "Order that should not be processed."


class OrderResult(tp.NamedTuple):
    # Result of processing an order; referenced by the `order_result` field of the post-order contexts.
    size: float
    price: float
    fees: float
    side: int
    status: int
    status_info: int


__pdoc__["OrderResult"] = "A named tuple representing an order result."
__pdoc__["OrderResult.size"] = "Filled size."
__pdoc__["OrderResult.price"] = "Filled price per unit, adjusted with slippage."
__pdoc__["OrderResult.fees"] = "Total fees paid for this order."
__pdoc__["OrderResult.side"] = "See `OrderSide`."
__pdoc__["OrderResult.status"] = "See `OrderStatus`."
__pdoc__["OrderResult.status_info"] = "See `OrderStatusInfo`."


class SignalSegmentContext(tp.NamedTuple):
    # Context passed to `post_segment_func_nb` in a from-signals simulation.
    # The declaration order of the fields is the positional layout of the tuple; do not reorder.
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    track_cash_deposits: bool
    cash_deposits_out: tp.Array2d
    track_cash_earnings: bool
    cash_earnings_out: tp.Array2d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.Array1d
    last_limit_info: tp.Array1d
    last_sl_info: tp.Array1d
    last_tsl_info: tp.Array1d
    last_tp_info: tp.Array1d
    last_td_info: tp.Array1d
    last_dt_info: tp.Array1d
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int


__pdoc__[
    "SignalSegmentContext"
] = """A named tuple representing the context of a segment in a from-signals simulation.

Contains information related to the cascade of the simulation, such as OHLC, but also
internal information that is not passed by the user but created at the beginning of the simulation.
To make use of other information, such as order size, use templates.

Passed to `post_segment_func_nb`."""
# NOTE(review): unlike the sibling contexts, these are three independent loops rather than a single
# if/elif chain, so a field declared by several of these contexts ends up pointing at the last one
# checked (`RowContext`) instead of the first. Confirm whether that is intended.
for field in SignalSegmentContext._fields:
    if field in SimulationContext._fields:
        __pdoc__["SignalSegmentContext." + field] = f"See `SimulationContext.{field}`."
for field in SignalSegmentContext._fields:
    if field in GroupContext._fields:
        __pdoc__["SignalSegmentContext." + field] = f"See `GroupContext.{field}`."
for field in SignalSegmentContext._fields:
    if field in RowContext._fields:
        __pdoc__["SignalSegmentContext." + field] = f"See `RowContext.{field}`."
__pdoc__[
    "SignalSegmentContext.track_cash_deposits"
] = """Whether to track cash deposits.

Becomes True if any value in `cash_deposits` is not zero."""
__pdoc__["SignalSegmentContext.cash_deposits_out"] = "See `SimulationOutput.cash_deposits`."
__pdoc__[
    "SignalSegmentContext.track_cash_earnings"
] = """Whether to track cash earnings.

Becomes True if any value in `cash_earnings` is not zero."""
__pdoc__["SignalSegmentContext.cash_earnings_out"] = "See `SimulationOutput.cash_earnings`."
__pdoc__["SignalSegmentContext.in_outputs"] = "See `FSInOutputs`."
# The access example names the context field (`last_limit_info`), not the dtype, like its siblings below.
__pdoc__[
    "SignalSegmentContext.last_limit_info"
] = """Record of type `limit_info_dt` per column.

Accessible via `c.last_limit_info[field][col]`."""
__pdoc__[
    "SignalSegmentContext.last_sl_info"
] = """Record of type `sl_info_dt` per column.

Accessible via `c.last_sl_info[field][col]`."""
__pdoc__[
    "SignalSegmentContext.last_tsl_info"
] = """Record of type `tsl_info_dt` per column.

Accessible via `c.last_tsl_info[field][col]`."""
__pdoc__[
    "SignalSegmentContext.last_tp_info"
] = """Record of type `tp_info_dt` per column.

Accessible via `c.last_tp_info[field][col]`."""
__pdoc__[
    "SignalSegmentContext.last_td_info"
] = """Record of type `time_info_dt` per column.

Accessible via `c.last_td_info[field][col]`."""
__pdoc__[
    "SignalSegmentContext.last_dt_info"
] = """Record of type `time_info_dt` per column.

Accessible via `c.last_dt_info[field][col]`."""


class SignalContext(tp.NamedTuple):
    # Context passed to `signal_func_nb` and `adjust_func_nb`: `SignalSegmentContext` plus `col`.
    # The declaration order of the fields is the positional layout of the tuple; do not reorder.
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    track_cash_deposits: bool
    cash_deposits_out: tp.Array2d
    track_cash_earnings: bool
    cash_earnings_out: tp.Array2d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.Array1d
    last_limit_info: tp.Array1d
    last_sl_info: tp.Array1d
    last_tsl_info: tp.Array1d
    last_tp_info: tp.Array1d
    last_td_info: tp.Array1d
    last_dt_info: tp.Array1d
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int
    col: int


__pdoc__[
    "SignalContext"
] = """A named tuple representing the context of an element in a from-signals simulation.

Contains all fields from `SignalSegmentContext` plus the column field.

Passed to `signal_func_nb` and `adjust_func_nb`.
"""
# Every field shared with `SignalSegmentContext` refers back to its documentation there.
for field in SignalContext._fields:
    if field in SignalSegmentContext._fields:
        __pdoc__["SignalContext." + field] = f"See `SignalSegmentContext.{field}`."
__pdoc__["SignalContext.col"] = "See `OrderContext.col`."


class PostSignalContext(tp.NamedTuple):
    # Context passed to `post_signal_func_nb`: `SignalContext` plus the previous balances and order result.
    # The declaration order of the fields is the positional layout of the tuple; do not reorder.
    target_shape: tp.Shape
    group_lens: tp.GroupLens
    cash_sharing: bool
    index: tp.Optional[tp.Array1d]
    freq: tp.Optional[int]
    open: tp.FlexArray2d
    high: tp.FlexArray2d
    low: tp.FlexArray2d
    close: tp.FlexArray2d
    init_cash: tp.FlexArray1d
    init_position: tp.FlexArray1d
    init_price: tp.FlexArray1d
    order_records: tp.RecordArray2d
    order_counts: tp.Array1d
    log_records: tp.RecordArray2d
    log_counts: tp.Array1d
    track_cash_deposits: bool
    cash_deposits_out: tp.Array2d
    track_cash_earnings: bool
    cash_earnings_out: tp.Array2d
    in_outputs: tp.Optional[tp.NamedTuple]
    last_cash: tp.Array1d
    last_position: tp.Array1d
    last_debt: tp.Array1d
    last_locked_cash: tp.Array1d
    last_free_cash: tp.Array1d
    last_val_price: tp.Array1d
    last_value: tp.Array1d
    last_return: tp.Array1d
    last_pos_info: tp.Array1d
    last_limit_info: tp.Array1d
    last_sl_info: tp.Array1d
    last_tsl_info: tp.Array1d
    last_tp_info: tp.Array1d
    last_td_info: tp.Array1d
    last_dt_info: tp.Array1d
    sim_start: tp.Array1d
    sim_end: tp.Array1d
    group: int
    group_len: int
    from_col: int
    to_col: int
    i: int
    col: int
    # State before the order was executed
    cash_before: float
    position_before: float
    debt_before: float
    locked_cash_before: float
    free_cash_before: float
    val_price_before: float
    value_before: float
    # Outcome of the order
    order_result: "OrderResult"


__pdoc__[
    "PostSignalContext"
] = """A named tuple representing the context after an order has been processed in a from-signals simulation.

Contains all fields from `SignalContext` plus the previous balances and order result.

Passed to `post_signal_func_nb`.
"""
# Every field shared with `SignalContext` refers back to its documentation there.
for field in PostSignalContext._fields:
    if field in SignalContext._fields:
        __pdoc__["PostSignalContext." + field] = f"See `SignalContext.{field}`."
# Each `*_before` field documents the state of the same name (the last three used to be shifted
# by one, and `value_before` had no entry at all).
__pdoc__["PostSignalContext.cash_before"] = "`ExecState.cash` before execution."
__pdoc__["PostSignalContext.position_before"] = "`ExecState.position` before execution."
__pdoc__["PostSignalContext.debt_before"] = "`ExecState.debt` before execution."
__pdoc__["PostSignalContext.locked_cash_before"] = "`ExecState.locked_cash` before execution."
__pdoc__["PostSignalContext.free_cash_before"] = "`ExecState.free_cash` before execution."
__pdoc__["PostSignalContext.val_price_before"] = "`ExecState.val_price` before execution."
__pdoc__["PostSignalContext.value_before"] = "`ExecState.value` before execution."
__pdoc__["PostSignalContext.order_result"] = "`PostOrderContext.order_result`."


# ############# In-outputs ############# #


class FOInOutputs(tp.NamedTuple):
    # In-output arrays of the simulation based on orders; all fields are two-dimensional arrays.
    cash: tp.Array2d
    position: tp.Array2d
    debt: tp.Array2d
    locked_cash: tp.Array2d
    free_cash: tp.Array2d
    value: tp.Array2d
    returns: tp.Array2d


__pdoc__["FOInOutputs"] = "A named tuple representing the in-outputs for simulation based on orders."
__pdoc__[
    "FOInOutputs.cash"
] = """See `AccountState.cash`.

Follows groups if cash sharing is enabled, otherwise columns.

Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
    "FOInOutputs.position"
] = """See `AccountState.position`.

Follows columns.

Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
    "FOInOutputs.debt"
] = """See `AccountState.debt`.

Follows columns.

Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
    "FOInOutputs.locked_cash"
] = """See `AccountState.locked_cash`.

Follows columns.

Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
    "FOInOutputs.free_cash"
] = """See `AccountState.free_cash`.

Follows groups if cash sharing is enabled, otherwise columns.

Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
# NOTE(review): the sibling entries name `save_state`/`save_returns`; `fill_value` below may be
# meant to read `save_value` - confirm against the simulation function's signature.
__pdoc__[
    "FOInOutputs.value"
] = """Value.

Follows groups if cash sharing is enabled, otherwise columns.

Gets filled if `fill_value` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
    "FOInOutputs.returns"
] = """Returns.

Follows groups if cash sharing is enabled, otherwise columns.

Gets filled if `save_returns` is True, otherwise has the shape `(0, 0)`."""


class FSInOutputs(tp.NamedTuple):
    # In-output arrays of the simulation based on signals; same fields as `FOInOutputs`.
    cash: tp.Array2d
    position: tp.Array2d
    debt: tp.Array2d
    locked_cash: tp.Array2d
    free_cash: tp.Array2d
    value: tp.Array2d
    returns: tp.Array2d


__pdoc__["FSInOutputs"] = "A named tuple representing the in-outputs for simulation based on signals."
__pdoc__["FSInOutputs.cash"] = "See `FOInOutputs.cash`."
__pdoc__["FSInOutputs.position"] = "See `FOInOutputs.position`."
__pdoc__["FSInOutputs.debt"] = "See `FOInOutputs.debt`."
__pdoc__["FSInOutputs.locked_cash"] = "See `FOInOutputs.locked_cash`."
__pdoc__["FSInOutputs.free_cash"] = "See `FOInOutputs.free_cash`."
__pdoc__["FSInOutputs.value"] = "See `FOInOutputs.value`."
__pdoc__["FSInOutputs.returns"] = "See `FOInOutputs.returns`."


# ############# Records ############# #

# Each record type is declared as a list of `(name, type)` pairs followed by the aligned
# structured dtype built from it. The order of the pairs is the field order of the dtype.

order_fields = [
    ("id", int_),
    ("col", int_),
    ("idx", int_),
    ("size", float_),
    ("price", float_),
    ("fees", float_),
    ("side", int_),
]
"""Fields for `order_dt`."""

order_dt = np.dtype(order_fields, align=True)
"""_"""

__pdoc__[
    "order_dt"
] = f"""`np.dtype` of order records.

```python
{prettify(order_dt)}
```
"""

fs_order_fields = [
    ("id", int_),
    ("col", int_),
    ("signal_idx", int_),
    ("creation_idx", int_),
    ("idx", int_),
    ("size", float_),
    ("price", float_),
    ("fees", float_),
    ("side", int_),
    ("type", int_),
    ("stop_type", int_),
]
"""Fields for `fs_order_dt`."""

fs_order_dt = np.dtype(fs_order_fields, align=True)
"""_"""

__pdoc__[
    "fs_order_dt"
] = f"""`np.dtype` of order records generated from signals.

```python
{prettify(fs_order_dt)}
```
"""

trade_fields = [
    ("id", int_),
    ("col", int_),
    ("size", float_),
    ("entry_order_id", int_),
    ("entry_idx", int_),
    ("entry_price", float_),
    ("entry_fees", float_),
    ("exit_order_id", int_),
    ("exit_idx", int_),
    ("exit_price", float_),
    ("exit_fees", float_),
    ("pnl", float_),
    ("return", float_),
    ("direction", int_),
    ("status", int_),
    ("parent_id", int_),
]
"""Fields for `trade_dt`."""

trade_dt = np.dtype(trade_fields, align=True)
"""_"""

__pdoc__[
    "trade_dt"
] = f"""`np.dtype` of trade records.

```python
{prettify(trade_dt)}
```
"""

# Log record layout: price area, state before (`st0_*`), request (`req_*`),
# result (`res_*`), state after (`st1_*`), and the id of the order.
log_fields = [
    ("id", int_),
    ("group", int_),
    ("col", int_),
    ("idx", int_),
    ("price_area_open", float_),
    ("price_area_high", float_),
    ("price_area_low", float_),
    ("price_area_close", float_),
    ("st0_cash", float_),
    ("st0_position", float_),
    ("st0_debt", float_),
    ("st0_locked_cash", float_),
    ("st0_free_cash", float_),
    ("st0_val_price", float_),
    ("st0_value", float_),
    ("req_size", float_),
    ("req_price", float_),
    ("req_size_type", int_),
    ("req_direction", int_),
    ("req_fees", float_),
    ("req_fixed_fees", float_),
    ("req_slippage", float_),
    ("req_min_size", float_),
    ("req_max_size", float_),
    ("req_size_granularity", float_),
    ("req_leverage", float_),
    ("req_leverage_mode", int_),
    ("req_reject_prob", float_),
    ("req_price_area_vio_mode", int_),
    ("req_allow_partial", np.bool_),
    ("req_raise_reject", np.bool_),
    ("req_log", np.bool_),
    ("res_size", float_),
    ("res_price", float_),
    ("res_fees", float_),
    ("res_side", int_),
    ("res_status", int_),
    ("res_status_info", int_),
    ("st1_cash", float_),
    ("st1_position", float_),
    ("st1_debt", float_),
    ("st1_locked_cash", float_),
    ("st1_free_cash", float_),
    ("st1_val_price", float_),
    ("st1_value", float_),
    ("order_id", int_),
]
"""Fields for `log_dt`."""

log_dt = np.dtype(log_fields, align=True)
"""_"""

__pdoc__[
    "log_dt"
] = f"""`np.dtype` of log records.

```python
{prettify(log_dt)}
```
"""

alloc_range_fields = [
    ("id", int_),
    ("col", int_),
    ("start_idx", int_),
    ("end_idx", int_),
    ("alloc_idx", int_),
    ("status", int_),
]
"""Fields for `alloc_range_dt`."""

alloc_range_dt = np.dtype(alloc_range_fields, align=True)
"""_"""

__pdoc__[
    "alloc_range_dt"
] = f"""`np.dtype` of allocation range records.

```python
{prettify(alloc_range_dt)}
```
"""

alloc_point_fields = [
    ("id", int_),
    ("col", int_),
    ("alloc_idx", int_),
]
"""Fields for `alloc_point_dt`."""

alloc_point_dt = np.dtype(alloc_point_fields, align=True)
"""_"""

__pdoc__[
    "alloc_point_dt"
] = f"""`np.dtype` of allocation point records.

```python
{prettify(alloc_point_dt)}
```
"""

# ############# Info records ############# #

main_info_fields = [
    ("bar_zone", int_),
    ("signal_idx", int_),
    ("creation_idx", int_),
    ("idx", int_),
    ("val_price", float_),
    ("price", float_),
    ("size", float_),
    ("size_type", int_),
    ("direction", int_),
    ("type", int_),
    ("stop_type", int_),
]
"""Fields for `main_info_dt`."""

main_info_dt = np.dtype(main_info_fields, align=True)
"""_"""

__pdoc__[
    "main_info_dt"
] = f"""`np.dtype` of main information records.

```python
{prettify(main_info_dt)}
```

Attributes:
    bar_zone: See `vectorbtpro.generic.enums.BarZone`.
    signal_idx: Row where signal was placed.
    creation_idx: Row where order was created.
    idx: Row from where order information was taken.
    val_price: Valuation price.
    price: Requested price.
    size: Order size.
    size_type: See `SizeType`.
    direction: See `Direction`.
    type: See `OrderType`.
    stop_type: See `vectorbtpro.signals.enums.StopType`.
"""

limit_info_fields = [
    ("signal_idx", int_),
    ("creation_idx", int_),
    ("init_idx", int_),
    ("init_price", float_),
    ("init_size", float_),
    ("init_size_type", int_),
    ("init_direction", int_),
    ("init_stop_type", int_),
    ("delta", float_),
    ("delta_format", int_),
    ("tif", np.int64),
    ("expiry", np.int64),
    ("time_delta_format", int_),
    ("reverse", float_),
    ("order_price", float_),
]
"""Fields for `limit_info_dt`."""

limit_info_dt = np.dtype(limit_info_fields, align=True)
"""_"""

__pdoc__[
    "limit_info_dt"
] = f"""`np.dtype` of limit information records.

```python
{prettify(limit_info_dt)}
```

Attributes:
    signal_idx: Signal row.
    creation_idx: Limit creation row.
    init_idx: Initial row from where order information is taken.
    init_price: Initial price.
    init_size: Order size.
    init_size_type: See `SizeType`.
    init_direction: See `Direction`.
    init_stop_type: See `vectorbtpro.signals.enums.StopType`.
    delta: Delta from the initial price.
    delta_format: See `DeltaFormat`.
    tif: Time in force in integer format. Set to `-1` to disable.
    expiry: Expiry time in integer format. Set to `-1` to disable.
    time_delta_format: See `TimeDeltaFormat`.
    reverse: Whether to reverse the price hit detection.
    order_price: See `LimitOrderPrice`.
"""

sl_info_fields = [
    ("init_idx", int_),
    ("init_price", float_),
    ("init_position", float_),
    ("stop", float_),
    ("exit_price", float_),
    ("exit_size", float_),
    ("exit_size_type", int_),
    ("exit_type", int_),
    ("order_type", int_),
    ("limit_delta", float_),
    ("delta_format", int_),
    ("ladder", int_),
    ("step", int_),
    ("step_idx", int_),
]
"""Fields for `sl_info_dt`."""

sl_info_dt = np.dtype(sl_info_fields, align=True)
"""_"""

__pdoc__[
    "sl_info_dt"
] = f"""`np.dtype` of SL information records.

```python
{prettify(sl_info_dt)}
```

Attributes:
    init_idx: Initial row.
    init_price: Initial price.
    init_position: Initial position.
    stop: Latest updated stop value.
    exit_price: See `StopExitPrice`.
    exit_size: Order size.
    exit_size_type: See `SizeType`.
    exit_type: See `StopExitType`.
    order_type: See `OrderType`.
    limit_delta: Delta from the hit price. Only for `StopType.Limit`.
    delta_format: See `DeltaFormat`.
    ladder: See `StopLadderMode`.
    step: Step in the ladder (i.e., the number of times the stop was executed)
    step_idx: Step row.
"""

tsl_info_fields = [
    ("init_idx", int_),
    ("init_price", float_),
    ("init_position", float_),
    ("peak_idx", int_),
    ("peak_price", float_),
    ("stop", float_),
    ("th", float_),
    ("exit_price", float_),
    ("exit_size", float_),
    ("exit_size_type", int_),
    ("exit_type", int_),
    ("order_type", int_),
    ("limit_delta", float_),
    ("delta_format", int_),
    ("ladder", int_),
    ("step", int_),
    ("step_idx", int_),
]
"""Fields for `tsl_info_dt`."""

tsl_info_dt = np.dtype(tsl_info_fields, align=True)
"""_"""

__pdoc__[
    "tsl_info_dt"
] = f"""`np.dtype` of TSL information records.

```python
{prettify(tsl_info_dt)}
```

Attributes:
    init_idx: Initial row.
    init_price: Initial price.
    init_position: Initial position.
    peak_idx: Row of the highest/lowest price.
    peak_price: Highest/lowest price.
    stop: Latest updated stop value.
    th: Latest updated threshold value.
    exit_price: See `StopExitPrice`.
    exit_size: Order size.
    exit_size_type: See `SizeType`.
    exit_type: See `StopExitType`.
    order_type: See `OrderType`.
    limit_delta: Delta from the hit price. Only for `StopType.Limit`.
    delta_format: See `DeltaFormat`.
    ladder: See `StopLadderMode`.
    step: Step in the ladder (i.e., the number of times the stop was executed)
    step_idx: Step row.
"""

tp_info_fields = [
    ("init_idx", int_),
    ("init_price", float_),
    ("init_position", float_),
    ("stop", float_),
    ("exit_price", float_),
    ("exit_size", float_),
    ("exit_size_type", int_),
    ("exit_type", int_),
    ("order_type", int_),
    ("limit_delta", float_),
    ("delta_format", int_),
    ("ladder", int_),
    ("step", int_),
    ("step_idx", int_),
]
"""Fields for `tp_info_dt`."""

tp_info_dt = np.dtype(tp_info_fields, align=True)
"""_"""

__pdoc__[
    "tp_info_dt"
] = f"""`np.dtype` of TP information records.

```python
{prettify(tp_info_dt)}
```

Attributes:
    init_idx: Initial row.
    init_price: Initial price.
    init_position: Initial position.
    stop: Latest updated stop value.
    exit_price: See `StopExitPrice`.
    exit_size: Order size.
    exit_size_type: See `SizeType`.
    exit_type: See `StopExitType`.
    order_type: See `OrderType`.
    limit_delta: Delta from the hit price. Only for `StopType.Limit`.
    delta_format: See `DeltaFormat`.
    ladder: See `StopLadderMode`.
    step: Step in the ladder (i.e., the number of times the stop was executed)
    step_idx: Step row.
"""

time_info_fields = [
    ("init_idx", int_),
    ("init_position", float_),
    ("stop", np.int64),
    ("exit_price", float_),
    ("exit_size", float_),
    ("exit_size_type", int_),
    ("exit_type", int_),
    ("order_type", int_),
    ("limit_delta", float_),
    ("delta_format", int_),
    ("time_delta_format", int_),
    ("ladder", int_),
    ("step", int_),
    ("step_idx", int_),
]
"""Fields for `time_info_dt`."""

time_info_dt = np.dtype(time_info_fields, align=True)
"""_"""

__pdoc__[
    "time_info_dt"
] = f"""`np.dtype` of time information records.

```python
{prettify(time_info_dt)}
```

Attributes:
    init_idx: Initial row.
    init_position: Initial position.
    stop: Latest updated stop value.
    exit_price: See `StopExitPrice`.
    exit_size: Order size.
    exit_size_type: See `SizeType`.
    exit_type: See `StopExitType`.
    order_type: See `OrderType`.
    limit_delta: Delta from the hit price. Only for `StopType.Limit`.
    delta_format: See `DeltaFormat`. Only for `StopType.Limit`.
    time_delta_format: See `TimeDeltaFormat`.
    ladder: See `StopLadderMode`.
    step: Step in the ladder (i.e., the number of times the stop was executed)
    step_idx: Step row.
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Base class for working with log records.

Log records capture information on simulation logs. Logs are populated when
simulating a portfolio and can be accessed as `vectorbtpro.portfolio.base.Portfolio.logs`.

```pycon
>>> from vectorbtpro import *

>>> np.random.seed(42)
>>> price = pd.DataFrame({
...     'a': np.random.uniform(1, 2, size=100),
...     'b': np.random.uniform(1, 2, size=100)
... }, index=[datetime(2020, 1, 1) + timedelta(days=i) for i in range(100)])
>>> size = pd.DataFrame({
...     'a': np.random.uniform(-100, 100, size=100),
...     'b': np.random.uniform(-100, 100, size=100),
... }, index=[datetime(2020, 1, 1) + timedelta(days=i) for i in range(100)])
>>> pf = vbt.Portfolio.from_orders(price, size, fees=0.01, freq='d', log=True)
>>> logs = pf.logs

>>> logs.filled.count()
a    88
b    99
Name: count, dtype: int64

>>> logs.ignored.count()
a    0
b    0
Name: count, dtype: int64

>>> logs.rejected.count()
a    12
b     1
Name: count, dtype: int64
```

## Stats

!!! hint
    See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Logs.metrics`.

```pycon >>> logs['a'].stats() Start 2020-01-01 00:00:00 End 2020-04-09 00:00:00 Period 100 days 00:00:00 Total Records 100 Status Counts: None 0 Status Counts: Filled 88 Status Counts: Ignored 0 Status Counts: Rejected 12 Status Info Counts: None 88 Status Info Counts: NoCashLong 12 Name: a, dtype: object ``` `Logs.stats` also supports (re-)grouping: ```pycon >>> logs.stats(group_by=True) Start 2020-01-01 00:00:00 End 2020-04-09 00:00:00 Period 100 days 00:00:00 Total Records 200 Status Counts: None 0 Status Counts: Filled 187 Status Counts: Ignored 0 Status Counts: Rejected 13 Status Info Counts: None 187 Status Info Counts: NoCashLong 13 Name: group, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Logs.subplots`. This class does not have any subplots. """ import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.base.reshaping import to_dict from vectorbtpro.generic.price_records import PriceRecords from vectorbtpro.portfolio.enums import ( log_dt, SizeType, LeverageMode, PriceAreaVioMode, Direction, OrderSide, OrderStatus, OrderStatusInfo, ) from vectorbtpro.records.decorators import attach_fields, override_field_config from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig __all__ = [ "Logs", ] __pdoc__ = {} logs_field_config = ReadonlyConfig( dict( dtype=log_dt, settings=dict( id=dict(title="Log Id"), col=dict(title="Column"), idx=dict(title="Index"), group=dict(title="Group"), price_area_open=dict(title="[PA] Open"), price_area_high=dict(title="[PA] High"), price_area_low=dict(title="[PA] Low"), price_area_close=dict(title="[PA] Close"), st0_cash=dict(title="[ST0] Cash"), st0_position=dict(title="[ST0] Position"), st0_debt=dict(title="[ST0] Debt"), st0_locked_cash=dict(title="[ST0] Locked Cash"), st0_free_cash=dict(title="[ST0] Free Cash"), st0_val_price=dict(title="[ST0] Valuation Price"), st0_value=dict(title="[ST0] Value"), req_size=dict(title="[REQ] 
# Per-field attachment options passed to the `attach_fields` decorator on `Logs`.
# Only the three order-result fields are listed, each with `attach_filters=True`.
# NOTE(review): presumably `attach_filters` makes the decorator generate one filtered
# view per value of the field's mapping (as `side_buy`/`side_sell` on `Orders`) - confirm
# against `vectorbtpro.records.decorators.attach_fields`.
logs_attach_field_config = ReadonlyConfig(
    dict(
        res_side=dict(attach_filters=True),
        res_status=dict(attach_filters=True),
        res_status_info=dict(attach_filters=True),
    )
)
@attach_fields(logs_attach_field_config)
@override_field_config(logs_field_config)
class Logs(PriceRecords):
    """Extends `vectorbtpro.generic.price_records.PriceRecords` for working with log records.

    The class decorators override the field config with `logs_field_config` and attach
    the fields configured in `logs_attach_field_config`. The metrics defined here refer to
    `res_status` and `res_status_info` by name in their `calc_func` strings."""

    @property
    def field_config(self) -> Config:
        """Field config of this instance (returns `_field_config`)."""
        return self._field_config

    # ############# Stats ############# #

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Logs.stats`.

        Merges `vectorbtpro.generic.price_records.PriceRecords.stats_defaults` and
        `stats` from `vectorbtpro._settings.logs`."""
        # Imported lazily inside the property rather than at module level
        from vectorbtpro._settings import settings

        logs_stats_cfg = settings["logs"]["stats"]

        # The parent property is invoked explicitly on this instance; settings for logs
        # are passed last to `merge_dicts`
        return merge_dicts(PriceRecords.stats_defaults.__get__(self), logs_stats_cfg)

    # Metric definitions for `Logs.stats`.
    # NOTE(review): the lambda parameter names (`self`, `out`, `settings`) are presumably
    # matched by name by the stats builder - do not rename without confirming.
    _metrics: tp.ClassVar[Config] = HybridConfig(
        dict(
            # First label of the wrapper's index
            start_index=dict(
                title="Start Index",
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags="wrapper",
            ),
            # Last label of the wrapper's index
            end_index=dict(
                title="End Index",
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags="wrapper",
            ),
            # Number of index entries; flagged with `apply_to_timedelta=True`
            total_duration=dict(
                title="Total Duration",
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags="wrapper",
            ),
            total_records=dict(title="Total Records", calc_func="count", tags="records"),
            # Value counts of the result status, converted to a dict via `to_dict`
            res_status_counts=dict(
                title="Status Counts",
                calc_func="res_status.value_counts",
                incl_all_keys=True,
                post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
                tags=["logs", "res_status", "value_counts"],
            ),
            # Value counts of the result status info; unlike the metric above,
            # `incl_all_keys` is not set here
            res_status_info_counts=dict(
                title="Status Info Counts",
                calc_func="res_status_info.value_counts",
                post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
                tags=["logs", "res_status_info", "value_counts"],
            ),
        )
    )

    @property
    def metrics(self) -> Config:
        """Metrics of this instance (returns `_metrics`)."""
        return self._metrics

    # ############# Plotting ############# #

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Logs.plots`.

        Merges `vectorbtpro.generic.price_records.PriceRecords.plots_defaults` and
        `plots` from `vectorbtpro._settings.logs`."""
        from vectorbtpro._settings import settings

        logs_plots_cfg = settings["logs"]["plots"]

        return merge_dicts(PriceRecords.plots_defaults.__get__(self), logs_plots_cfg)

    @property
    def subplots(self) -> Config:
        """Subplots of this instance (returns `_subplots`).

        This class does not define `_subplots` itself, so the attribute is looked up
        on a parent class."""
        return self._subplots
```pycon >>> pf.orders['a'].stats() Start 2019-12-31 22:00:00+00:00 End 2020-02-29 22:00:00+00:00 Period 61 days 00:00:00 Total Records 41 Side Counts: Buy 17 Side Counts: Sell 24 Size: Min 0 days 19:33:05.006182372 Size: Median 1 days 00:00:00 Size: Max 1 days 00:00:00 Fees: Min 0 days 20:26:25.905776572 Fees: Median 0 days 22:46:22.693324744 Fees: Max 1 days 01:04:25.541681491 Weighted Buy Price 94.69917 Weighted Sell Price 95.742148 Name: a, dtype: object ``` `Orders.stats` also supports (re-)grouping: ```pycon >>> pf.orders.stats(group_by=True) Start 2019-12-31 22:00:00+00:00 End 2020-02-29 22:00:00+00:00 Period 61 days 00:00:00 Total Records 82 Side Counts: Buy 32 Side Counts: Sell 50 Size: Min 0 days 19:33:05.006182372 Size: Median 1 days 00:00:00 Size: Max 1 days 00:00:00 Fees: Min 0 days 20:26:25.905776572 Fees: Median 0 days 23:58:29.773897679 Fees: Max 1 days 02:29:08.904770159 Weighted Buy Price 98.804452 Weighted Sell Price 99.969934 Name: group, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Orders.subplots`. 
# Field config applied to `Orders` through the `override_field_config` decorator.
# Binds the record dtype to `order_dt`, assigns display titles to the fields, and maps
# `side` through the `OrderSide` enum. `side` is also flagged `as_customdata=False`.
# `idx` gets an empty dict, i.e. no settings are overridden for it here.
orders_field_config = ReadonlyConfig(
    dict(
        dtype=order_dt,
        settings=dict(
            id=dict(title="Order Id"),
            idx=dict(),
            size=dict(title="Size"),
            price=dict(title="Price"),
            fees=dict(title="Fees"),
            side=dict(title="Side", mapping=OrderSide, as_customdata=False),
        ),
    )
)
@attach_shortcut_properties(orders_shortcut_config)
@attach_fields(orders_attach_field_config)
@override_field_config(orders_field_config)
class Orders(PriceRecords):
    """Extends `vectorbtpro.generic.price_records.PriceRecords` for working with order records.

    The class decorators override the field config with `orders_field_config`, attach the
    fields configured in `orders_attach_field_config`, and attach the shortcut properties
    listed in `orders_shortcut_config`."""

    @property
    def field_config(self) -> Config:
        """Field config of this instance (returns `_field_config`)."""
        return self._field_config

    # ############# Views ############# #

    def get_long_view(
        self: OrdersT,
        init_position: tp.ArrayLike = 0.0,
        init_price: tp.ArrayLike = np.nan,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
    ) -> OrdersT:
        """See `vectorbtpro.portfolio.nb.records.get_long_view_orders_nb`.

        Args:
            init_position (array_like): Initial position, converted with `to_1d_array`.
            init_price (array_like): Initial price, converted with `to_1d_array`.
            jitted: Option resolved through `jit_reg.resolve_option`.
            chunked: Option resolved through `ch_reg.resolve_option`.

        Returns:
            A copy of this instance (via `replace`) holding the new records array."""
        func = jit_reg.resolve_option(nb.get_long_view_orders_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        new_records_arr = func(
            self.records_arr,
            to_2d_array(self.close),
            self.col_mapper.col_map,
            init_position=to_1d_array(init_position),
            init_price=to_1d_array(init_price),
        )
        return self.replace(records_arr=new_records_arr)

    def get_short_view(
        self: OrdersT,
        init_position: tp.ArrayLike = 0.0,
        init_price: tp.ArrayLike = np.nan,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
    ) -> OrdersT:
        """See `vectorbtpro.portfolio.nb.records.get_short_view_orders_nb`.

        Takes the same arguments as `Orders.get_long_view` and likewise returns a copy of
        this instance (via `replace`) holding the new records array."""
        func = jit_reg.resolve_option(nb.get_short_view_orders_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        new_records_arr = func(
            self.records_arr,
            to_2d_array(self.close),
            self.col_mapper.col_map,
            init_position=to_1d_array(init_position),
            init_price=to_1d_array(init_price),
        )
        return self.replace(records_arr=new_records_arr)

    # ############# Stats ############# #

    def get_signed_size(self, **kwargs) -> MappedArray:
        """Get signed size.

        Copies the `size` field and negates the entries whose `side` is `OrderSide.Sell`.
        Keyword arguments are passed to `map_array`."""
        # Work on a copy so the underlying records are not modified in place
        size = self.get_field_arr("size").copy()
        size[self.get_field_arr("side") == OrderSide.Sell] *= -1
        return self.map_array(size, **kwargs)

    def get_value(self, **kwargs) -> MappedArray:
        """Get value.

        Computed as signed size times price, so sell orders yield the negated product.
        Keyword arguments are passed to `map_array`."""
        return self.map_array(self.signed_size.values * self.price.values, **kwargs)

    def get_weighted_price(
        self,
        group_by: tp.GroupByLike = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.MaybeSeries:
        """Get size-weighted price average.

        Reduces the `size` and `price` fields with `nb.weighted_price_reduce_meta_nb`
        through `MappedArray.reduce`. The unsigned `size` field is used as-is.

        `wrap_kwargs` is merged over the default `name_or_index="weighted_price"`;
        remaining keyword arguments are passed to `MappedArray.reduce`."""
        wrap_kwargs = merge_dicts(dict(name_or_index="weighted_price"), wrap_kwargs)
        return MappedArray.reduce(
            nb.weighted_price_reduce_meta_nb,
            self.get_field_arr("size"),
            self.get_field_arr("price"),
            group_by=group_by,
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
            col_mapper=self.col_mapper,
            **kwargs,
        )

    def get_price_status(self, **kwargs) -> MappedArray:
        """See `vectorbtpro.portfolio.nb.records.price_status_nb`.

        Applies the function to the records together with the stored high and low prices
        and maps the result through `OrderPriceStatus`."""
        return self.apply(
            nb.price_status_nb,
            self._high,
            self._low,
            mapping=OrderPriceStatus,
            **kwargs,
        )

    @property
    def stats_defaults(self) -> tp.Kwargs:
        """Defaults for `Orders.stats`.

        Merges `vectorbtpro.generic.price_records.PriceRecords.stats_defaults` and
        `stats` from `vectorbtpro._settings.orders`."""
        from vectorbtpro._settings import settings

        orders_stats_cfg = settings["orders"]["stats"]

        return merge_dicts(PriceRecords.stats_defaults.__get__(self), orders_stats_cfg)

    # Metric definitions for `Orders.stats`.
    # NOTE(review): the lambda parameter names (`self`, `out`, `settings`) are presumably
    # matched by name by the stats builder - do not rename without confirming.
    _metrics: tp.ClassVar[Config] = HybridConfig(
        dict(
            start_index=dict(
                title="Start Index",
                calc_func=lambda self: self.wrapper.index[0],
                agg_func=None,
                tags="wrapper",
            ),
            end_index=dict(
                title="End Index",
                calc_func=lambda self: self.wrapper.index[-1],
                agg_func=None,
                tags="wrapper",
            ),
            total_duration=dict(
                title="Total Duration",
                calc_func=lambda self: len(self.wrapper.index),
                apply_to_timedelta=True,
                agg_func=None,
                tags="wrapper",
            ),
            total_records=dict(title="Total Records", calc_func="count", tags="records"),
            side_counts=dict(
                title="Side Counts",
                calc_func="side.value_counts",
                incl_all_keys=True,
                post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
                tags=["orders", "side"],
            ),
            size=dict(
                title="Size",
                calc_func="size.describe",
                post_calc_func=lambda self, out, settings: {
                    "Min": out.loc["min"],
                    "Median": out.loc["50%"],
                    "Max": out.loc["max"],
                },
                tags=["orders", "size"],
            ),
            fees=dict(
                title="Fees",
                calc_func="fees.describe",
                post_calc_func=lambda self, out, settings: {
                    "Min": out.loc["min"],
                    "Median": out.loc["50%"],
                    "Max": out.loc["max"],
                },
                tags=["orders", "fees"],
            ),
            weighted_buy_price=dict(
                title="Weighted Buy Price",
                calc_func="side_buy.get_weighted_price",
                tags=["orders", "buy", "price"],
            ),
            weighted_sell_price=dict(
                title="Weighted Sell Price",
                calc_func="side_sell.get_weighted_price",
                tags=["orders", "sell", "price"],
            ),
        )
    )

    @property
    def metrics(self) -> Config:
        """Metrics of this instance (returns `_metrics`)."""
        return self._metrics

    # ############# Plotting ############# #

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_ohlc: bool = True,
        plot_close: bool = True,
        ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
        ohlc_trace_kwargs: tp.KwargsLike = None,
        close_trace_kwargs: tp.KwargsLike = None,
        buy_trace_kwargs: tp.KwargsLike = None,
        sell_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot orders.

        Args:
            column (str): Name of the column to plot.
            plot_ohlc (bool): Whether to plot OHLC.

                Requires open, high, low, and close to be available.
            plot_close (bool): Whether to plot close.

                Close is only plotted when OHLC is not plotted.
            ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

                Pass None to use the default.
            ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.

                If "opacity" is not provided, it defaults to 0.5.
                The passed dict is not modified.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Orders.close`.
            buy_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Buy" markers.
            sell_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Sell" markers.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Usage:
            ```pycon
            >>> index = pd.date_range("2020", periods=5)
            >>> price = pd.Series([1., 2., 3., 2., 1.], index=index)
            >>> size = pd.Series([1., 1., 1., 1., -1.], index=index)
            >>> orders = vbt.Portfolio.from_orders(price, size).orders

            >>> orders.plot().show()
            ```

            ![](/assets/images/api/orders_plot.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/orders_plot.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("plotly")
        import plotly.graph_objects as go
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column, group_by=False)

        # Shallow-copy the caller's dict: the opacity default below is written into it,
        # and that write must not leak back into an object owned by the caller
        ohlc_trace_kwargs = {} if ohlc_trace_kwargs is None else dict(ohlc_trace_kwargs)
        if close_trace_kwargs is None:
            close_trace_kwargs = {}
        close_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
            close_trace_kwargs,
        )
        if buy_trace_kwargs is None:
            buy_trace_kwargs = {}
        if sell_trace_kwargs is None:
            sell_trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        # Plot price
        if (
            plot_ohlc
            and self_col._open is not None
            and self_col._high is not None
            and self_col._low is not None
            and self_col._close is not None
        ):
            ohlc_df = pd.DataFrame(
                {
                    "open": self_col.open,
                    "high": self_col.high,
                    "low": self_col.low,
                    "close": self_col.close,
                }
            )
            if "opacity" not in ohlc_trace_kwargs:
                ohlc_trace_kwargs["opacity"] = 0.5
            fig = ohlc_df.vbt.ohlcv.plot(
                ohlc_type=ohlc_type,
                plot_volume=False,
                ohlc_trace_kwargs=ohlc_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        elif plot_close and self_col._close is not None:
            fig = self_col.close.vbt.lineplot(
                trace_kwargs=close_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )

        if self_col.count() > 0:
            # Extract information
            idx = self_col.get_map_field_to_index("idx")
            price = self_col.get_field_arr("price")
            side = self_col.get_field_arr("side")

            buy_mask = side == OrderSide.Buy
            if buy_mask.any():
                # Plot buy markers
                buy_customdata, buy_hovertemplate = self_col.prepare_customdata(mask=buy_mask)
                _buy_trace_kwargs = merge_dicts(
                    dict(
                        x=idx[buy_mask],
                        y=price[buy_mask],
                        mode="markers",
                        marker=dict(
                            symbol="triangle-up",
                            color=plotting_cfg["contrast_color_schema"]["green"],
                            size=8,
                            line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"])),
                        ),
                        name="Buy",
                        customdata=buy_customdata,
                        hovertemplate=buy_hovertemplate,
                    ),
                    buy_trace_kwargs,
                )
                buy_scatter = go.Scatter(**_buy_trace_kwargs)
                fig.add_trace(buy_scatter, **add_trace_kwargs)

            sell_mask = side == OrderSide.Sell
            if sell_mask.any():
                # Plot sell markers
                sell_customdata, sell_hovertemplate = self_col.prepare_customdata(mask=sell_mask)
                _sell_trace_kwargs = merge_dicts(
                    dict(
                        x=idx[sell_mask],
                        y=price[sell_mask],
                        mode="markers",
                        marker=dict(
                            symbol="triangle-down",
                            color=plotting_cfg["contrast_color_schema"]["red"],
                            size=8,
                            line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["red"])),
                        ),
                        name="Sell",
                        customdata=sell_customdata,
                        hovertemplate=sell_hovertemplate,
                    ),
                    sell_trace_kwargs,
                )
                sell_scatter = go.Scatter(**_sell_trace_kwargs)
                fig.add_trace(sell_scatter, **add_trace_kwargs)

        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `Orders.plots`.

        Merges `vectorbtpro.generic.price_records.PriceRecords.plots_defaults` and
        `plots` from `vectorbtpro._settings.orders`."""
        from vectorbtpro._settings import settings

        orders_plots_cfg = settings["orders"]["plots"]

        return merge_dicts(PriceRecords.plots_defaults.__get__(self), orders_plots_cfg)

    # Single subplot that delegates to `Orders.plot`
    _subplots: tp.ClassVar[Config] = HybridConfig(
        dict(
            plot=dict(
                title="Orders",
                yaxis_kwargs=dict(title="Price"),
                check_is_not_grouped=True,
                plot_func="plot",
                tags="orders",
            )
        )
    )

    @property
    def subplots(self) -> Config:
        """Subplots of this instance (returns `_subplots`)."""
        return self._subplots
@attach_shortcut_properties(fs_orders_shortcut_config)
@attach_fields(fs_orders_attach_field_config)
@override_field_config(fs_orders_field_config)
class FSOrders(Orders):
    """Extends `Orders` for working with order records generated from signals.

    On top of the fill index `idx`, the field config of these records declares
    `signal_idx`, `creation_idx`, `type`, and `stop_type`."""

    @property
    def field_config(self) -> Config:
        """Field config of this instance (returns `_field_config`)."""
        return self._field_config

    def get_stop_orders(self, **kwargs):
        """Get stop orders.

        Keeps the records whose `stop_type` field is not -1.
        Keyword arguments are passed to `apply_mask`."""
        return self.apply_mask(self.get_field_arr("stop_type") != -1, **kwargs)

    def _get_ranges_between(self, start_field: str, end_field: str, **kwargs) -> Ranges:
        """Build `vectorbtpro.generic.ranges.Ranges` spanning from one index field to another.

        Args:
            start_field (str): Name of the record field used as `start_idx`.
            end_field (str): Name of the record field used as `end_idx`.
            **kwargs: Keyword arguments passed to `Ranges.from_records`.

        Every range is marked as `RangeStatus.Closed`. The wrapper and the stored
        open, high, low, and close of this instance are forwarded to the new records."""
        new_records_arr = np.empty(self.values.shape, dtype=range_dt)
        # Slice assignment copies the values into the new array, so the field arrays
        # do not need to be copied beforehand
        new_records_arr["id"][:] = self.get_field_arr("id")
        new_records_arr["col"][:] = self.get_field_arr("col")
        new_records_arr["start_idx"][:] = self.get_field_arr(start_field)
        new_records_arr["end_idx"][:] = self.get_field_arr(end_field)
        new_records_arr["status"][:] = RangeStatus.Closed
        return Ranges.from_records(
            self.wrapper,
            new_records_arr,
            open=self._open,
            high=self._high,
            low=self._low,
            close=self._close,
            **kwargs,
        )

    def get_ranges(self, **kwargs) -> Ranges:
        """Get records of type `vectorbtpro.generic.ranges.Ranges` for signal-to-fill ranges."""
        return self._get_ranges_between("signal_idx", "idx", **kwargs)

    def get_creation_ranges(self, **kwargs) -> Ranges:
        """Get records of type `vectorbtpro.generic.ranges.Ranges` for signal-to-creation ranges."""
        return self._get_ranges_between("signal_idx", "creation_idx", **kwargs)

    def get_fill_ranges(self, **kwargs) -> Ranges:
        """Get records of type `vectorbtpro.generic.ranges.Ranges` for creation-to-fill ranges."""
        return self._get_ranges_between("creation_idx", "idx", **kwargs)

    def get_signal_to_creation_duration(self, **kwargs) -> MappedArray:
        """Get duration between signal and creation.

        Computed as the `creation_idx` field minus the `signal_idx` field."""
        duration = self.get_field_arr("creation_idx") - self.get_field_arr("signal_idx")
        return self.map_array(duration, **kwargs)

    def get_creation_to_fill_duration(self, **kwargs) -> MappedArray:
        """Get duration between creation and fill.

        Computed as the `idx` field minus the `creation_idx` field."""
        duration = self.get_field_arr("idx") - self.get_field_arr("creation_idx")
        return self.map_array(duration, **kwargs)

    def get_signal_to_fill_duration(self, **kwargs) -> MappedArray:
        """Get duration between signal and fill.

        Computed as the `idx` field minus the `signal_idx` field."""
        duration = self.get_field_arr("idx") - self.get_field_arr("signal_idx")
        return self.map_array(duration, **kwargs)

    # Metric definitions for `FSOrders.stats`: the metrics of `Orders` are reused by key,
    # with type counts inserted after the side counts and average durations appended
    _metrics: tp.ClassVar[Config] = HybridConfig(
        start_index=Orders.metrics["start_index"],
        end_index=Orders.metrics["end_index"],
        total_duration=Orders.metrics["total_duration"],
        total_records=Orders.metrics["total_records"],
        side_counts=Orders.metrics["side_counts"],
        type_counts=dict(
            title="Type Counts",
            calc_func="type.value_counts",
            incl_all_keys=True,
            post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
            tags=["orders", "type"],
        ),
        stop_type_counts=dict(
            title="Stop Type Counts",
            calc_func="stop_type.value_counts",
            incl_all_keys=True,
            post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
            tags=["orders", "stop_type"],
        ),
        size=Orders.metrics["size"],
        fees=Orders.metrics["fees"],
        weighted_buy_price=Orders.metrics["weighted_buy_price"],
        weighted_sell_price=Orders.metrics["weighted_sell_price"],
        avg_signal_to_creation_duration=dict(
            title="Avg Signal-Creation Duration",
            calc_func="signal_to_creation_duration.mean",
            apply_to_timedelta=True,
            tags=["orders", "duration"],
        ),
        avg_creation_to_fill_duration=dict(
            title="Avg Creation-Fill Duration",
            calc_func="creation_to_fill_duration.mean",
            apply_to_timedelta=True,
            tags=["orders", "duration"],
        ),
        avg_signal_to_fill_duration=dict(
            title="Avg Signal-Fill Duration",
            calc_func="signal_to_fill_duration.mean",
            apply_to_timedelta=True,
            tags=["orders", "duration"],
        ),
    )

    @property
    def metrics(self) -> Config:
        """Metrics of this instance (returns `_metrics`)."""
        return self._metrics
@register_jitted(cache=True)
def valid_price_from_ago_1d_nb(price: tp.Array1d) -> tp.Array1d:
    """Compute, for each element, how many steps back the previous valid price is.

    An element that is NaN, or that sits at the first position, gets 1. Any other element
    gets the distance to the closest earlier non-NaN element; if no earlier element is
    valid, the distance to the first position is used."""
    n_rows = price.shape[0]
    from_ago = np.empty(price.shape, dtype=int_)
    for i in range(n_rows):
        if i == 0 or np.isnan(price[i]):
            from_ago[i] = 1
            continue
        # Walk back until a valid price is hit or the first position is reached
        prev = i - 1
        while prev > 0 and np.isnan(price[prev]):
            prev -= 1
        from_ago[i] = i - prev
    return from_ago
class PFPrepResult(Configured):
    """Result of preparation.

    Holds the target function, the arguments for it, and the portfolio arguments.
    All three are forwarded to `Configured.__init__` as keyword arguments and are read
    back from `config` by the cached properties below.

    Args:
        target_func (callable): Target function.
        target_args (dict): Target arguments.
        pf_args (dict): Portfolio arguments.
        **kwargs: Additional keyword arguments forwarded to `Configured.__init__`.
    """

    def __init__(
        self,
        target_func: tp.Optional[tp.Callable] = None,
        target_args: tp.Optional[tp.Kwargs] = None,
        pf_args: tp.Optional[tp.Kwargs] = None,
        **kwargs,
    ) -> None:
        Configured.__init__(
            self,
            target_func=target_func,
            target_args=target_args,
            pf_args=pf_args,
            **kwargs,
        )

    @cachedproperty
    def target_func(self) -> tp.Optional[tp.Callable]:
        """Target function."""
        return self.config["target_func"]

    @cachedproperty
    def target_args(self) -> tp.Optional[tp.Kwargs]:
        """Target arguments.

        None if not provided upon construction."""
        return self.config["target_args"]

    @cachedproperty
    def pf_args(self) -> tp.Optional[tp.Kwargs]:
        """Portfolio arguments."""
        return self.config["pf_args"]
```python {base_arg_config.prettify()} ``` """ @attach_arg_properties @override_arg_config(base_arg_config) class BasePFPreparer(BasePreparer): """Base class for preparing portfolio simulations.""" _settings_path: tp.SettingsPath = "portfolio" @classmethod def find_target_func(cls, target_func_name: str) -> tp.Callable: return getattr(nb, target_func_name) # ############# Ready arguments ############# # @cachedproperty def init_cash_mode(self) -> tp.Optional[int]: """Initial cash mode.""" init_cash = self["init_cash"] if checks.is_int(init_cash) and init_cash in enums.InitCashMode: return init_cash return None @cachedproperty def group_by(self) -> tp.GroupByLike: """Argument `group_by`.""" group_by = self["group_by"] if group_by is None and self.cash_sharing: return True return group_by @cachedproperty def auto_call_seq(self) -> bool: """Whether automatic call sequence is enabled.""" call_seq = self["call_seq"] return checks.is_int(call_seq) and call_seq == enums.CallSeqType.Auto @classmethod def parse_data( cls, data: tp.Union[None, OHLCDataMixin, str, tp.ArrayLike], all_ohlc: bool = False, ) -> tp.Optional[OHLCDataMixin]: """Parse an instance with OHLC features.""" if data is None: return None if isinstance(data, OHLCDataMixin): return data if isinstance(data, str): return Data.from_data_str(data) if isinstance(data, pd.DataFrame): ohlcv_acc = data.vbt.ohlcv if all_ohlc and ohlcv_acc.has_ohlc: return ohlcv_acc if not all_ohlc and ohlcv_acc.has_any_ohlc: return ohlcv_acc return None @cachedproperty def data(self) -> tp.Optional[OHLCDataMixin]: """Argument `data`.""" return self.parse_data(self["data"]) # ############# Before broadcasting ############# # @cachedproperty def _pre_open(self) -> tp.ArrayLike: """Argument `open` before broadcasting.""" open = self["open"] if open is None: if self.data is not None: open = self.data.open if open is None: return np.nan return open @cachedproperty def _pre_high(self) -> tp.ArrayLike: """Argument `high` before 
broadcasting.""" high = self["high"] if high is None: if self.data is not None: high = self.data.high if high is None: return np.nan return high @cachedproperty def _pre_low(self) -> tp.ArrayLike: """Argument `low` before broadcasting.""" low = self["low"] if low is None: if self.data is not None: low = self.data.low if low is None: return np.nan return low @cachedproperty def _pre_close(self) -> tp.ArrayLike: """Argument `close` before broadcasting.""" close = self["close"] if close is None: if self.data is not None: close = self.data.close if close is None: return np.nan return close @cachedproperty def _pre_bm_close(self) -> tp.Optional[tp.ArrayLike]: """Argument `bm_close` before broadcasting.""" bm_close = self["bm_close"] if bm_close is not None and not isinstance(bm_close, bool): return bm_close return np.nan @cachedproperty def _pre_init_cash(self) -> tp.ArrayLike: """Argument `init_cash` before broadcasting.""" if self.init_cash_mode is not None: return np.inf return self["init_cash"] @cachedproperty def _pre_init_position(self) -> tp.ArrayLike: """Argument `init_position` before broadcasting.""" return self["init_position"] @cachedproperty def _pre_init_price(self) -> tp.ArrayLike: """Argument `init_price` before broadcasting.""" return self["init_price"] @cachedproperty def _pre_cash_deposits(self) -> tp.ArrayLike: """Argument `cash_deposits` before broadcasting.""" return self["cash_deposits"] @cachedproperty def _pre_freq(self) -> tp.Optional[tp.FrequencyLike]: """Argument `freq` before casting to nanosecond format.""" freq = self["freq"] if freq is None and self.data is not None: return self.data.symbol_wrapper.freq return freq @cachedproperty def _pre_call_seq(self) -> tp.Optional[tp.ArrayLike]: """Argument `call_seq` before broadcasting.""" if self.auto_call_seq: return None return self["call_seq"] @cachedproperty def _pre_in_outputs(self) -> tp.Union[None, tp.NamedTuple, CustomTemplate]: """Argument `in_outputs` before broadcasting.""" in_outputs = 
def align_pc_arr(
    self,
    arr: tp.ArrayLike,
    group_lens: tp.Optional[tp.GroupLens] = None,
    check_dtype: tp.Optional[tp.DTypeLike] = None,
    cast_to_dtype: tp.Optional[tp.DTypeLike] = None,
    reduce_func: tp.Union[None, str, tp.Callable] = None,
    arg_name: tp.Optional[str] = None,
) -> tp.Array1d:
    """Align a per-column array.

    Args:
        arr: Array-like to align; converted with `np.asarray`.
        group_lens: Target group lengths. If provided, the result is broadcast to
            `len(group_lens)`, otherwise to the number of columns in the wrapper.
        check_dtype: If provided, `arr` must be a subdtype of it.
        cast_to_dtype: If provided, `arr` is cast to this data type.
        reduce_func: Name of an array method (such as "sum") or a callable used to
            reduce the values of the columns in each group to a single value.
        arg_name: Name of the argument, used in the data type check.

    Returns:
        Array broadcast to the target length.

    !!! note
        Intermediate buffers inherit the data type of `arr` (after casting), such that
        floating arrays (for example, `init_cash` cast to `float_` and reduced with "sum")
        are not silently truncated to integers.
    """
    arr = np.asarray(arr)
    if check_dtype is not None:
        checks.assert_subdtype(arr, check_dtype, arg_name=arg_name)
    if cast_to_dtype is not None:
        arr = np.require(arr, dtype=cast_to_dtype)
    # Regrouping only makes sense for multiple values and when a reducer is available
    if arr.size > 1 and group_lens is not None and reduce_func is not None:
        n_cols = len(self.wrapper.columns)
        if len(self.group_lens) == len(arr) != len(group_lens) == n_cols:
            # One value per group of this preparer but one value per column requested:
            # copy each group's value into the positions yielded for that group
            new_arr = np.empty(n_cols, dtype=arr.dtype)
            for i, cols in enumerate(self.wrapper.grouper.iter_group_idxs()):
                new_arr[cols] = arr[i]
            arr = new_arr
        if n_cols == len(arr) != len(group_lens):
            # One value per column but one value per target group requested:
            # reduce the values selected for each group into a single value
            new_arr = np.empty(len(group_lens), dtype=arr.dtype)
            for i, cols in enumerate(self.wrapper.grouper.iter_group_lens(group_lens)):
                if isinstance(reduce_func, str):
                    new_arr[i] = getattr(arr[cols], reduce_func)()
                else:
                    new_arr[i] = reduce_func(arr[cols])
            arr = new_arr
    if group_lens is not None:
        return broadcast_array_to(arr, len(group_lens))
    return broadcast_array_to(arr, len(self.wrapper.columns))
@cachedproperty
def sim_end(self) -> tp.Optional[tp.ArrayLike]:
    """Argument `sim_end`.

    Returns None if the argument is not set, or if it is "auto" (case-insensitive)
    and `auto_sim_end` yields None. A scalar is returned as-is when it is an integer
    and passed through `SimRangeMixin.resolve_sim_end_value` otherwise. Multiple values
    are resolved element by element (unless already integer) and then aligned with
    `align_pc_arr` against `sim_group_lens` using "max" as the reducer.
    """
    raw_end = self["sim_end"]
    if raw_end is None:
        return None
    if isinstance(raw_end, str) and raw_end.lower() == "auto":
        raw_end = self.auto_sim_end
        if raw_end is None:
            return None

    end_arr = np.asarray(raw_end)
    is_scalar = end_arr.ndim == 0
    if np.issubdtype(end_arr.dtype, np.integer):
        if is_scalar:
            return raw_end
        resolved = end_arr
    elif is_scalar:
        return SimRangeMixin.resolve_sim_end_value(raw_end, wrapper=self.wrapper)
    else:
        # Non-integer values are resolved into integers one at a time
        n_values = len(raw_end)
        resolved = np.empty(n_values, dtype=int_)
        for pos in range(n_values):
            resolved[pos] = SimRangeMixin.resolve_sim_end_value(raw_end[pos], wrapper=self.wrapper)

    return self.align_pc_arr(
        resolved,
        group_lens=self.sim_group_lens,
        check_dtype=np.integer,
        cast_to_dtype=int_,
        reduce_func="max",
        arg_name="sim_end",
    )
@cachedproperty
def pf_args(self) -> tp.Optional[tp.Kwargs]:
    """Arguments to be passed to the portfolio.

    Consist of the prepared price, cash, and position arguments plus every config entry
    that is neither listed in `arg_config` nor is the "arg_config" entry itself.
    """
    # Config entries without an argument config are forwarded untouched
    passthrough = {
        key: value
        for key, value in self.config.items()
        if key not in self.arg_config and key != "arg_config"
    }

    wrapper = self.wrapper
    # NaN placeholder (checked by identity) means the price was not provided
    open_ = None if self._pre_open is np.nan else self.open
    high = None if self._pre_high is np.nan else self.high
    low = None if self._pre_low is np.nan else self.low
    close = self.close
    cash_sharing = self.cash_sharing
    if self.init_cash_mode is None:
        init_cash = self.init_cash
    else:
        init_cash = self.init_cash_mode
    init_position = self.init_position
    init_price = self.init_price
    # None and booleans are forwarded as-is, anything else uses the prepared array
    raw_bm_close = self["bm_close"]
    if raw_bm_close is not None and not isinstance(raw_bm_close, bool):
        bm_close = self.bm_close
    else:
        bm_close = raw_bm_close

    return dict(
        wrapper=wrapper,
        open=open_,
        high=high,
        low=low,
        close=close,
        cash_sharing=cash_sharing,
        init_cash=init_cash,
        init_position=init_position,
        init_price=init_price,
        bm_close=bm_close,
        **passthrough,
    )
fill_default=False, ), price=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.PriceType, ignore_type=(int, float)), subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PriceType.Close)), ), size_type=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.SizeType), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.SizeType.Amount)), ), direction=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.Direction), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.Direction.Both)), ), fees=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)), ), fixed_fees=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)), ), slippage=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)), ), min_size=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), max_size=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), size_granularity=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), leverage=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=1.0)), ), leverage_mode=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.LeverageMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.LeverageMode.Lazy)), ), reject_prob=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)), ), price_area_vio_mode=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.PriceAreaVioMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PriceAreaVioMode.Ignore)), ), allow_partial=dict( broadcast=True, subdtype=np.bool_, 
fo_arg_config = ReadonlyConfig(
    {
        # Arguments declared as broadcastable, each with its own reindexing fill value
        "cash_dividends": {
            "broadcast": True,
            "subdtype": np.number,
            "broadcast_kwargs": {"reindex_kwargs": {"fill_value": 0.0}},
        },
        "val_price": {
            "broadcast": True,
            "map_enum_kwargs": {"enum": enums.ValPriceType, "ignore_type": (int, float)},
            "subdtype": np.number,
            "broadcast_kwargs": {"reindex_kwargs": {"fill_value": np.nan}},
        },
        "from_ago": {
            "broadcast": True,
            "subdtype": np.integer,
            "broadcast_kwargs": {"reindex_kwargs": {"fill_value": 0}},
        },
        # Arguments declared without any additional settings
        "ffill_val_price": {},
        "update_value": {},
        "save_state": {},
        "save_value": {},
        "save_returns": {},
        "skip_empty": {},
        "max_order_records": {},
        "max_log_records": {},
    }
)
"""_"""
@attach_arg_properties
@override_arg_config(fo_arg_config)
@override_arg_config(order_arg_config)
class FOPreparer(BasePFPreparer):
    """Class for preparing `vectorbtpro.portfolio.base.Portfolio.from_orders`."""

    _settings_path: tp.SettingsPath = "portfolio.from_orders"

    # ############# Ready arguments ############# #

    @cachedproperty
    def staticized(self) -> tp.StaticizedOption:
        """Argument `staticized`.

        Raises:
            ValueError: Always, since this preparer doesn't support staticization.
        """
        raise ValueError("This method doesn't support staticization")

    # ############# Before broadcasting ############# #

    @cachedproperty
    def _pre_from_ago(self) -> tp.ArrayLike:
        """Argument `from_ago` before broadcasting.

        Falls back to 0 if the argument is not set."""
        from_ago = self["from_ago"]
        if from_ago is not None:
            return from_ago
        return 0

    @cachedproperty
    def _pre_max_order_records(self) -> tp.Optional[int]:
        """Argument `max_order_records` before broadcasting."""
        return self["max_order_records"]

    @cachedproperty
    def _pre_max_log_records(self) -> tp.Optional[int]:
        """Argument `max_log_records` before broadcasting."""
        return self["max_log_records"]

    # ############# After broadcasting ############# #

    @cachedproperty
    def sim_group_lens(self) -> tp.GroupLens:
        """Simulation group lengths, which are the cash sharing aware group lengths here."""
        return self.cs_group_lens

    @cachedproperty
    def auto_sim_start(self) -> tp.Optional[tp.ArrayLike]:
        """Get automatic `sim_start` from the first valid index in `size`.

        Returns None if `size` has a single row. Columns without any valid index start at 0."""
        size = to_2d_array(self.size)
        if size.shape[0] == 1:
            return None
        first_valid_idx = generic_nb.first_valid_index_nb(size, check_inf=False)
        first_valid_idx = np.where(first_valid_idx == -1, 0, first_valid_idx)
        return first_valid_idx

    @cachedproperty
    def auto_sim_end(self) -> tp.Optional[tp.ArrayLike]:
        """Get automatic `sim_end` from the last valid index in `size`.

        Returns None if `size` has a single row. The end is the position right after the last
        valid index; columns without any valid index end at the length of the index."""
        size = to_2d_array(self.size)
        if size.shape[0] == 1:
            return None
        last_valid_idx = generic_nb.last_valid_index_nb(size, check_inf=False)
        last_valid_idx = np.where(last_valid_idx == -1, len(self.wrapper.index), last_valid_idx + 1)
        return last_valid_idx

    @cachedproperty
    def price_and_from_ago(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike]:
        """Arguments `price` and `from_ago` after broadcasting.

        If `from_ago` is not set and `price` consists of a single element or a single row,
        "next" price options are translated into their current-bar counterparts:

        * `NextOpen` and `NextClose` become `Open` and `Close` with `from_ago` of 1
        * `NextValidOpen` and `NextValidClose` become `Open` and `Close` with `from_ago`
            computed from the respective price column by `valid_price_from_ago_1d_nb`

        Whenever a "next valid" option is present, both arrays are expanded to the full
        2-dim shape of the wrapper. Every column is filled: a single price option is applied
        to all columns, and columns with any other price keep that price with `from_ago` of 0.
        """
        price = self._post_price
        from_ago = self._post_from_ago
        if self["from_ago"] is None:
            if price.size == 1 or price.shape[0] == 1:
                next_open_mask = price == enums.PriceType.NextOpen
                next_close_mask = price == enums.PriceType.NextClose
                next_valid_open_mask = price == enums.PriceType.NextValidOpen
                next_valid_close_mask = price == enums.PriceType.NextValidClose

                if next_valid_open_mask.any() or next_valid_close_mask.any():
                    new_price = np.empty(self.wrapper.shape_2d, float_)
                    new_from_ago = np.empty(self.wrapper.shape_2d, int_)
                    # Price arrays are expanded only if some column actually needs them
                    if next_valid_open_mask.any():
                        open_arr = broadcast_array_to(self.open, self.wrapper.shape_2d)
                    if next_valid_close_mask.any():
                        close_arr = broadcast_array_to(self.close, self.wrapper.shape_2d)

                    # Buffers are uninitialized, thus each column must be written explicitly
                    flat_price = np.ravel(price)
                    for col in range(self.wrapper.shape_2d[1]):
                        # A single price option applies to every column
                        price_item = flat_price.item(col if flat_price.size > 1 else 0)
                        if price_item == enums.PriceType.NextOpen:
                            new_price[:, col] = enums.PriceType.Open
                            new_from_ago[:, col] = 1
                        elif price_item == enums.PriceType.NextClose:
                            new_price[:, col] = enums.PriceType.Close
                            new_from_ago[:, col] = 1
                        elif price_item == enums.PriceType.NextValidOpen:
                            new_price[:, col] = enums.PriceType.Open
                            new_from_ago[:, col] = valid_price_from_ago_1d_nb(open_arr[:, col])
                        elif price_item == enums.PriceType.NextValidClose:
                            new_price[:, col] = enums.PriceType.Close
                            new_from_ago[:, col] = valid_price_from_ago_1d_nb(close_arr[:, col])
                        else:
                            # Any other price is kept as-is and isn't delayed
                            new_price[:, col] = price_item
                            new_from_ago[:, col] = 0
                    price = new_price
                    from_ago = new_from_ago

                elif next_open_mask.any() or next_close_mask.any():
                    price = price.astype(float_)
                    price[next_open_mask] = enums.PriceType.Open
                    price[next_close_mask] = enums.PriceType.Close
                    from_ago = np.full(price.shape, 0, dtype=int_)
                    from_ago[next_open_mask] = 1
                    from_ago[next_close_mask] = 1
        return price, from_ago

    @cachedproperty
    def price(self) -> tp.ArrayLike:
        """Argument `price`."""
        return self.price_and_from_ago[0]

    @cachedproperty
    def from_ago(self) -> tp.ArrayLike:
        """Argument `from_ago`."""
        return self.price_and_from_ago[1]

    @cachedproperty
    def max_order_records(self) -> tp.Optional[int]:
        """Argument `max_order_records`.

        If not set, gets derived from the number of non-NaN elements in `size`: the number of
        target rows (or zero) when `size` is a single element or a single row, otherwise
        the maximum count of non-NaN elements across columns."""
        max_order_records = self._pre_max_order_records
        if max_order_records is None:
            _size = self._post_size
            if _size.size == 1:
                max_order_records = self.target_shape[0] * int(not np.isnan(_size.item(0)))
            else:
                if _size.shape[0] == 1 and self.target_shape[0] > 1:
                    max_order_records = self.target_shape[0] * int(np.any(~np.isnan(_size)))
                else:
                    max_order_records = int(np.max(np.sum(~np.isnan(_size), axis=0)))
        return max_order_records

    @cachedproperty
    def max_log_records(self) -> tp.Optional[int]:
        """Argument `max_log_records`.

        If not set, gets derived from the number of True elements in `log`: the number of
        target rows (or zero) when `log` is a single element or a single row, otherwise
        the maximum count of True elements across columns."""
        max_log_records = self._pre_max_log_records
        if max_log_records is None:
            _log = self._post_log
            if _log.size == 1:
                max_log_records = self.target_shape[0] * int(_log.item(0))
            else:
                if _log.shape[0] == 1 and self.target_shape[0] > 1:
                    max_log_records = self.target_shape[0] * int(np.any(_log))
                else:
                    max_log_records = int(np.max(np.sum(_log, axis=0)))
        return max_log_records

    # ############# Template substitution ############# #

    @cachedproperty
    def template_context(self) -> tp.Kwargs:
        """Template context, that is, this preparer's entries merged with the context of `BasePFPreparer`."""
        return merge_dicts(
            dict(
                group_lens=self.group_lens if self.dynamic_mode else self.cs_group_lens,
                ffill_val_price=self.ffill_val_price,
                update_value=self.update_value,
                save_state=self.save_state,
                save_value=self.save_value,
                save_returns=self.save_returns,
                max_order_records=self.max_order_records,
                max_log_records=self.max_log_records,
            ),
            BasePFPreparer.template_context.func(self),
        )

    # ############# Result ############# #

    @cachedproperty
    def target_func(self) -> tp.Optional[tp.Callable]:
        """Target function, that is, `nb.from_orders_nb` with the jitting and chunking options resolved."""
        func = jit_reg.resolve_option(nb.from_orders_nb, self.jitted)
        func = ch_reg.resolve_option(func, self.chunked)
        return func

    @cachedproperty
    def target_arg_map(self) -> tp.Kwargs:
        """Target argument map of `BasePFPreparer` where `group_lens` points to `cs_group_lens`."""
        target_arg_map = dict(BasePFPreparer.target_arg_map.func(self))
        target_arg_map["group_lens"] = "cs_group_lens"
        return target_arg_map
broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), exits=dict( has_default=False, broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), long_entries=dict( has_default=False, broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), long_exits=dict( has_default=False, broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), short_entries=dict( has_default=False, broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), short_exits=dict( has_default=False, broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), adjust_func_nb=dict(), adjust_args=dict(type="args", substitute_templates=True), signal_func_nb=dict(), signal_args=dict(type="args", substitute_templates=True), post_signal_func_nb=dict(), post_signal_args=dict(type="args", substitute_templates=True), post_segment_func_nb=dict(), post_segment_args=dict(type="args", substitute_templates=True), order_mode=dict(), val_price=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.ValPriceType, ignore_type=(int, float)), subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), accumulate=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.AccumulationMode, ignore_type=(int, bool)), subdtype=(np.integer, np.bool_), broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.AccumulationMode.Disabled)), ), upon_long_conflict=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.ConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.ConflictMode.Ignore)), ), upon_short_conflict=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.ConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.ConflictMode.Ignore)), ), upon_dir_conflict=dict( broadcast=True, 
map_enum_kwargs=dict(enum=enums.DirectionConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.DirectionConflictMode.Ignore)), ), upon_opposite_entry=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.OppositeEntryMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OppositeEntryMode.ReverseReduce)), ), order_type=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.OrderType), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OrderType.Market)), ), limit_delta=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), limit_tif=dict( broadcast=True, is_td=True, subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)), ), limit_expiry=dict( broadcast=True, is_dt=True, last_before=False, subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)), ), limit_reverse=dict( broadcast=True, subdtype=np.bool_, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)), ), limit_order_price=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.LimitOrderPrice, ignore_type=(int, float)), subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.LimitOrderPrice.Limit)), ), upon_adj_limit_conflict=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.PendingConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepIgnore)), ), upon_opp_limit_conflict=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.PendingConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.CancelExecute)), ), use_stops=dict(), stop_ladder=dict(map_enum_kwargs=dict(enum=enums.StopLadderMode, look_for_type=str)), sl_stop=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), tsl_stop=dict( broadcast=True, 
subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), tsl_th=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), tp_stop=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), td_stop=dict( broadcast=True, is_td=True, subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)), ), dt_stop=dict( broadcast=True, is_dt=True, subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)), ), stop_entry_price=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.StopEntryPrice, ignore_type=(int, float)), subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopEntryPrice.Close)), ), stop_exit_price=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.StopExitPrice, ignore_type=(int, float)), subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopExitPrice.Stop)), ), stop_exit_type=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.StopExitType), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopExitType.Close)), ), stop_order_type=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.OrderType), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OrderType.Market)), ), stop_limit_delta=dict( broadcast=True, subdtype=np.number, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)), ), upon_stop_update=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.StopUpdateMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopUpdateMode.Override)), ), upon_adj_stop_conflict=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.PendingConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepExecute)), ), upon_opp_stop_conflict=dict( broadcast=True, 
map_enum_kwargs=dict(enum=enums.PendingConflictMode), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepExecute)), ), delta_format=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.DeltaFormat), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.DeltaFormat.Percent)), ), time_delta_format=dict( broadcast=True, map_enum_kwargs=dict(enum=enums.TimeDeltaFormat), subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.TimeDeltaFormat.Index)), ), from_ago=dict( broadcast=True, subdtype=np.integer, broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0)), ), ffill_val_price=dict(), update_value=dict(), fill_pos_info=dict(), save_state=dict(), save_value=dict(), save_returns=dict(), skip_empty=dict(), max_order_records=dict(), max_log_records=dict(), records=dict( rename_fields=dict( entry="entries", exit="exits", long_entry="long_entries", long_exit="long_exits", short_entry="short_entries", short_exit="short_exits", ) ), ) ) """_""" __pdoc__[ "fs_arg_config" ] = f"""Argument config for `FSPreparer`. 
@cachedproperty
def implicit_mode(self) -> bool:
    """Whether the implicit mode is enabled.

    Enabled if at least one of the arguments `entries` and `exits` is set (is not None)."""
    return self["entries"] is not None or self["exits"] is not None
@cachedproperty
def signal_func_nb(self) -> tp.Optional[tp.Callable]:
    """Argument `signal_func_nb`.

    Outside of the dynamic mode, returns None. In the dynamic mode, returns the user-provided
    function if there is one, otherwise the first built-in function whose mode is enabled,
    checked in the order `ls_mode`, `signals_mode`, `order_mode`. Returns None if none of
    them is enabled.
    """
    if not self.dynamic_mode:
        return None
    user_func = self["signal_func_nb"]
    if user_func is not None:
        return user_func
    if self.ls_mode:
        return nb.ls_signal_func_nb
    if self.signals_mode:
        return nb.dir_signal_func_nb
    if self.order_mode:
        return nb.order_signal_func_nb
    return None
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
    """Argument `staticized`.

    Anything other than a dict is returned as-is. A dict is copied and, in the dynamic mode,
    adapted to each function in use via `adapt_staticized_to_udf`: the signal function
    (user-provided, or the built-in one of the first enabled mode, which also sets
    "suggest_fname"), the adjustment function, the post-signal function, and
    the post-segment function (user-provided, or the saving one if any of `save_state`,
    `save_value`, and `save_returns` is enabled).
    """
    option = self._pre_staticized
    if not isinstance(option, dict):
        return option
    option = dict(option)
    if not self.dynamic_mode:
        return option

    adapt = self.adapt_staticized_to_udf

    if self["signal_func_nb"] is not None:
        adapt(option, self["signal_func_nb"], "signal_func_nb")
    else:
        # Mode attributes are read lazily and in this order; the first enabled one wins
        builtin_signal_funcs = (
            ("ls_mode", "ls_signal_func_nb", "from_ls_signal_func_nb"),
            ("signals_mode", "dir_signal_func_nb", "from_dir_signal_func_nb"),
            ("order_mode", "order_signal_func_nb", "from_order_signal_func_nb"),
        )
        for mode_attr, udf_name, suggested_fname in builtin_signal_funcs:
            if getattr(self, mode_attr):
                adapt(option, udf_name, "signal_func_nb")
                option["suggest_fname"] = suggested_fname
                break

    for func_arg in ("adjust_func_nb", "post_signal_func_nb"):
        if self[func_arg] is not None:
            adapt(option, self[func_arg], func_arg)

    if self["post_segment_func_nb"] is not None:
        adapt(option, self["post_segment_func_nb"], "post_segment_func_nb")
    elif self.save_state or self.save_value or self.save_returns:
        adapt(option, "save_post_segment_func_nb", "post_segment_func_nb")
    return option
@classmethod
def init_in_outputs(
    cls,
    wrapper: ArrayWrapper,
    group_lens: tp.Optional[tp.GroupLens] = None,
    cash_sharing: bool = False,
    save_state: bool = True,
    save_value: bool = True,
    save_returns: bool = True,
) -> enums.FSInOutputs:
    """Initialize `vectorbtpro.portfolio.enums.FSInOutputs`.

    With cash sharing enabled and no group lengths provided, group lengths are taken from
    the grouper of the wrapper. The in-outputs are then created by `nb.init_FSInOutputs_nb`
    using the 2-dim shape of the wrapper.
    """
    needs_group_lens = cash_sharing and group_lens is None
    if needs_group_lens:
        group_lens = wrapper.grouper.get_group_lens()
    target_shape = wrapper.shape_2d
    return nb.init_FSInOutputs_nb(
        target_shape,
        group_lens,
        cash_sharing=cash_sharing,
        save_state=save_state,
        save_value=save_value,
        save_returns=save_returns,
    )
@cachedproperty
def signals(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike, tp.ArrayLike, tp.ArrayLike]:
    """Long/short signal arrays resolved from the broadcast signal arguments.

    Returns:
        A 4-tuple `(long_entries, long_exits, short_entries, short_exits)`.
        The per-element direction path returns the result of
        `nb.dir_to_ls_signals_nb` directly (presumably the same 4-tuple;
        confirm against that function).
    """
    if not self.dynamic_mode and not self.ls_mode:
        # Fixed mode without long/short-specific signals:
        # `entries`/`exits` are interpreted through `direction`
        entries = self._post_entries
        exits = self._post_exits
        direction = self._post_direction
        if direction.size == 1:
            # A single direction value applies to everything, so the arrays can be
            # re-labelled without computing anything; the unused side gets a 1x1 False array
            _direction = direction.item(0)
            if _direction == enums.Direction.LongOnly:
                long_entries = entries
                long_exits = exits
                short_entries = np.array([[False]])
                short_exits = np.array([[False]])
            elif _direction == enums.Direction.ShortOnly:
                long_entries = np.array([[False]])
                long_exits = np.array([[False]])
                short_entries = entries
                short_exits = exits
            else:
                # Any other direction: `entries` become long entries and `exits`
                # become short entries; no explicit exit signals on either side
                long_entries = entries
                long_exits = np.array([[False]])
                short_entries = exits
                short_exits = np.array([[False]])
        else:
            # Direction has more than one element: delegate the conversion
            return nb.dir_to_ls_signals_nb(
                target_shape=self.target_shape,
                entries=entries,
                exits=exits,
                direction=direction,
            )
    else:
        # NOTE(review): `explicit_mode`/`implicit_mode` are defined outside this block;
        # presumably explicit = `long_*` arguments were given, implicit = `entries`/`exits`
        # were given. Confirm against their definitions.
        if self.explicit_mode and self.implicit_mode:
            # Both kinds of arguments present: combine with OR.
            # NOTE(review): `entries`/`exits` are OR-ed into the short side as well,
            # whereas the implicit-only branch below treats them as long-only signals.
            # Confirm this is intended.
            long_entries = self._post_entries | self._post_long_entries
            long_exits = self._post_exits | self._post_long_exits
            short_entries = self._post_entries | self._post_short_entries
            short_exits = self._post_exits | self._post_short_exits
        elif self.explicit_mode:
            long_entries = self._post_long_entries
            long_exits = self._post_long_exits
            short_entries = self._post_short_entries
            short_exits = self._post_short_exits
        else:
            # `entries`/`exits` act as the long side here
            long_entries = self._post_entries
            long_exits = self._post_exits
            short_entries = self._post_short_entries
            short_exits = self._post_short_exits
    return long_entries, long_exits, short_entries, short_exits
@cachedproperty
def price_and_from_ago(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike]:
    """Resolve the broadcast `price` and `from_ago` arguments together.

    When `from_ago` was not provided and the broadcast price has a single
    element or a single row, `NextOpen`/`NextClose` price types are replaced
    by `Open`/`Close`, and `from_ago` is rebuilt as an integer array holding 1
    at those positions and 0 elsewhere. In every other case both arrays are
    returned exactly as broadcast.

    Returns:
        Tuple `(price, from_ago)`.
    """
    resolved_price = self._post_price
    resolved_from_ago = self._post_from_ago

    # A user-provided `from_ago` always wins
    if self["from_ago"] is not None:
        return resolved_price, resolved_from_ago

    # The rewrite is only attempted for single-element/single-row prices
    if not (resolved_price.size == 1 or resolved_price.shape[0] == 1):
        return resolved_price, resolved_from_ago

    is_next_open = resolved_price == enums.PriceType.NextOpen
    is_next_close = resolved_price == enums.PriceType.NextClose
    if not (is_next_open.any() or is_next_close.any()):
        return resolved_price, resolved_from_ago

    # `astype` yields a new float array, so the assignments below
    # do not write into the broadcast input
    resolved_price = resolved_price.astype(float_)
    resolved_price[is_next_open] = enums.PriceType.Open
    resolved_price[is_next_close] = enums.PriceType.Close

    resolved_from_ago = np.full(resolved_price.shape, 0, dtype=int_)
    resolved_from_ago[is_next_open] = 1
    resolved_from_ago[is_next_close] = 1
    return resolved_price, resolved_from_ago
@cachedproperty
def use_limit_orders(self) -> bool:
    """Whether limit orders are in play.

    True when any element of `order_type` equals `OrderType.Limit`, or when
    stops are enabled (`use_stops`) and any element of `stop_order_type`
    equals `OrderType.Limit`.
    """
    limit = enums.OrderType.Limit
    if np.any(self.order_type == limit):
        return True
    # `stop_order_type` only matters (and is only read) when stops are used
    return bool(self.use_stops and np.any(self.stop_order_type == limit))
@cachedproperty
def chunked(self) -> tp.ChunkedOption:
    """Argument `chunked`, specialized for the built-in signal functions.

    When running in dynamic mode without a user-provided `signal_func_nb`,
    the chunking option gets an `arg_take_spec` for `signal_args` that matches
    the default layout of that tuple: a run of flexible arrays, the adjustment
    function (present only when `staticized` is None), and the adjustment
    arguments. Otherwise the option is returned as provided.
    """
    if not self.dynamic_mode or self["signal_func_nb"] is not None:
        return self._pre_chunked

    # Number of leading flexible arrays in the default `signal_args` of each mode
    if self.ls_mode:
        n_flex_arrays = 5
    elif self.signals_mode:
        n_flex_arrays = 4
    elif self.order_mode:
        n_flex_arrays = 8
    else:
        return self._pre_chunked

    pre_chunked = self._pre_chunked
    # The adjustment function occupies a slot only when nothing is staticized
    adjust_func_slot = (None,) if self.staticized is None else ()
    signal_args_taker = ch.ArgsTaker(
        *[base_ch.flex_array_gl_slicer] * n_flex_arrays,
        *adjust_func_slot,
        ch.ArgsTaker(),
    )
    return ch.specialize_chunked_option(
        pre_chunked,
        arg_take_spec=dict(signal_args=signal_args_taker),
    )
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
    """Argument map for the target function, adjusted for the current mode.

    Starts from a copy of the base preparer's map. In fixed mode, `group_lens`
    is mapped to `cs_group_lens`. In dynamic mode with staticization enabled,
    the three user-defined function arguments are mapped to None.
    """
    arg_map = dict(BasePFPreparer.target_arg_map.func(self))
    if not self.dynamic_mode:
        arg_map["group_lens"] = "cs_group_lens"
        return arg_map
    if self.staticized is not None:
        for func_name in ("signal_func_nb", "post_signal_func_nb", "post_segment_func_nb"):
            arg_map[func_name] = None
    return arg_map
@cachedproperty
def _pre_staticized(self) -> tp.StaticizedOption:
    """Argument `staticized` before its resolution.

    A boolean is normalized first: True becomes an empty dict, False becomes
    None. A dict is copied and, if it has no "func" key, receives the
    simulation function matching the `flexible`/`row_wise` combination.
    Any other value is returned untouched.
    """
    option = self["staticized"]
    if isinstance(option, bool):
        option = dict() if option else None
    if not isinstance(option, dict):
        return option

    # Copy so the user's dict is never modified
    option = dict(option)
    if "func" not in option:
        if self.flexible:
            default_func = nb.from_flex_order_func_rw_nb if self.row_wise else nb.from_flex_order_func_nb
        else:
            default_func = nb.from_order_func_rw_nb if self.row_wise else nb.from_order_func_nb
        option["func"] = default_func
    return option
@cachedproperty
def order_func_nb(self) -> tp.Callable:
    """Argument `order_func_nb`.

    Returns the user-provided function, or `nb.no_order_func_nb` when the
    flexible mode is enabled and no function was provided.

    Raises:
        ValueError: If `order_func_nb` is provided in flexible mode, or
            missing in non-flexible mode.
    """
    user_func = self["order_func_nb"]
    provided = user_func is not None
    # Exactly one of the two must hold: flexible mode, or an order function given
    if bool(self.flexible) == provided:
        raise ValueError("Must provide either order_func_nb or flex_order_func_nb")
    return user_func if provided else nb.no_order_func_nb
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
    """Argument `staticized`.

    If the pre-resolved option is a dict, it is copied and passed to
    `adapt_staticized_to_udf` once for every user-defined function that was
    actually provided (is not None), in a fixed order. Non-dict options are
    returned as they are.
    """
    resolved = self._pre_staticized
    if not isinstance(resolved, dict):
        return resolved

    resolved = dict(resolved)
    udf_names = (
        "pre_sim_func_nb",
        "post_sim_func_nb",
        "pre_group_func_nb",
        "post_group_func_nb",
        "pre_row_func_nb",
        "post_row_func_nb",
        "pre_segment_func_nb",
        "post_segment_func_nb",
        "order_func_nb",
        "flex_order_func_nb",
        "post_order_func_nb",
    )
    for udf_name in udf_names:
        # Read the raw argument, not the property, so that defaults are not adapted
        udf = self[udf_name]
        if udf is not None:
            self.adapt_staticized_to_udf(resolved, udf, udf_name)
    return resolved
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
    """Argument map for the target function.

    Starts from a copy of the base preparer's map. When staticization is
    enabled (`staticized` is not None), every user-defined function argument
    is mapped to None.
    """
    arg_map = dict(BasePFPreparer.target_arg_map.func(self))
    if self.staticized is None:
        return arg_map

    udf_names = (
        "pre_sim_func_nb",
        "post_sim_func_nb",
        "pre_group_func_nb",
        "post_group_func_nb",
        "pre_row_func_nb",
        "post_row_func_nb",
        "pre_segment_func_nb",
        "post_segment_func_nb",
        "order_func_nb",
        "flex_order_func_nb",
        "post_order_func_nb",
    )
    for udf_name in udf_names:
        arg_map[udf_name] = None
    return arg_map
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
    """Argument `staticized`, extended with this preparer's resolved functions.

    Takes the option resolved by `FOFPreparer.staticized`. If it is a dict,
    then for each of `pre_segment_func_nb`, `order_func_nb`, and
    `flex_order_func_nb` whose name is not yet a key of that dict, the
    resolved property value (the built-in default when the user provided
    none) is passed to `adapt_staticized_to_udf`.
    """
    resolved = FOFPreparer.staticized.func(self)
    if isinstance(resolved, dict):
        for udf_name in ("pre_segment_func_nb", "order_func_nb", "flex_order_func_nb"):
            # The membership test runs right before each call, as the dict
            # may have been changed by the previous adaptation
            if udf_name not in resolved:
                self.adapt_staticized_to_udf(resolved, getattr(self, udf_name), udf_name)
    return resolved
@cachedproperty
def chunked(self) -> tp.ChunkedOption:
    """Argument `chunked`, specialized for the default order function arguments.

    `pre_segment_args` is described as five flexible arrays followed by one
    argument that is not chunked (None). The seventeen flexible arrays of the
    order arguments are registered under `flex_order_args` in flexible mode
    and under `order_args` otherwise.
    """
    flex_slicer = base_ch.flex_array_gl_slicer
    take_spec = dict(pre_segment_args=ch.ArgsTaker(*[flex_slicer] * 5, None))
    order_args_name = "flex_order_args" if self.flexible else "order_args"
    take_spec[order_args_name] = ch.ArgsTaker(*[flex_slicer] * 17)
    return ch.specialize_chunked_option(self._pre_chunked, arg_take_spec=take_spec)
This way, for example, single-order trades can be aggregated into positions; but also multiple positions can be aggregated into a single blob that reflects the performance of the entire symbol. !!! warning All classes return both closed AND open trades/positions, which may skew your performance results. To only consider closed trades/positions, you should explicitly query the `status_closed` attribute. ## Trade types There are three main types of trades. ### Entry trades An entry trade is created from each order that opens or adds to a position. For example, if we have a single large buy order and 100 smaller sell orders, we will see a single trade with the entry information copied from the buy order and the exit information being a size-weighted average over the exit information of all sell orders. On the other hand, if we have 100 smaller buy orders and a single sell order, we will see 100 trades, each with the entry information copied from the buy order and the exit information being a size-based fraction of the exit information of the sell order. Use `vectorbtpro.portfolio.trades.EntryTrades.from_orders` to build entry trades from orders. Also available as `vectorbtpro.portfolio.base.Portfolio.entry_trades`. ### Exit trades An exit trade is created from each order that closes or removes from a position. Use `vectorbtpro.portfolio.trades.ExitTrades.from_orders` to build exit trades from orders. Also available as `vectorbtpro.portfolio.base.Portfolio.exit_trades`. ### Positions A position is created from a sequence of entry or exit trades. Use `vectorbtpro.portfolio.trades.Positions.from_trades` to build positions from entry or exit trades. Also available as `vectorbtpro.portfolio.base.Portfolio.positions`. ## Example * Increasing position: ```pycon >>> from vectorbtpro import * >>> # Entry trades >>> pf_kwargs = dict( ... close=pd.Series([1., 2., 3., 4., 5.]), ... size=pd.Series([1., 1., 1., 1., -4.]), ... fixed_fees=1. ... 
) >>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades >>> entry_trades.readable Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 1 1 0 1.0 1 1 2.0 2 2 0 1.0 2 2 3.0 3 3 0 1.0 3 3 4.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 4 4 5.0 0.25 2.75 1 1.0 4 4 5.0 0.25 1.75 2 1.0 4 4 5.0 0.25 0.75 3 1.0 4 4 5.0 0.25 -0.25 Return Direction Status Position Id 0 2.7500 Long Closed 0 1 0.8750 Long Closed 0 2 0.2500 Long Closed 0 3 -0.0625 Long Closed 0 >>> # Exit trades >>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades >>> exit_trades.readable Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 4.0 0 0 2.5 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 4.0 4 4 5.0 1.0 5.0 Return Direction Status Position Id 0 0.5 Long Closed 0 >>> # Positions >>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions >>> positions.readable Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 4.0 0 0 2.5 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 4.0 4 4 5.0 1.0 5.0 Return Direction Status 0 0.5 Long Closed >>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum() True ``` * Decreasing position: ```pycon >>> # Entry trades >>> pf_kwargs = dict( ... close=pd.Series([1., 2., 3., 4., 5.]), ... size=pd.Series([4., -1., -1., -1., -1.]), ... fixed_fees=1. ... 
) >>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades >>> entry_trades.readable Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 4.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 4 4 3.5 4.0 5.0 Return Direction Status Position Id 0 1.25 Long Closed 0 >>> # Exit trades >>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades >>> exit_trades.readable Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 1 1 0 1.0 0 0 1.0 2 2 0 1.0 0 0 1.0 3 3 0 1.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 0.25 1 1 2.0 1.0 -0.25 1 0.25 2 2 3.0 1.0 0.75 2 0.25 3 3 4.0 1.0 1.75 3 0.25 4 4 5.0 1.0 2.75 Return Direction Status Position Id 0 -0.25 Long Closed 0 1 0.75 Long Closed 0 2 1.75 Long Closed 0 3 2.75 Long Closed 0 >>> # Positions >>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions >>> positions.readable Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 4.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 4 4 3.5 4.0 5.0 Return Direction Status 0 1.25 Long Closed >>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum() True ``` * Multiple reversing positions: ```pycon >>> # Entry trades >>> pf_kwargs = dict( ... close=pd.Series([1., 2., 3., 4., 5.]), ... size=pd.Series([1., -2., 2., -2., 1.]), ... fixed_fees=1. ... 
) >>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades >>> entry_trades.readable Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 1 1 0 1.0 1 1 2.0 2 2 0 1.0 2 2 3.0 3 3 0 1.0 3 3 4.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 1 1 2.0 0.5 -0.5 1 0.5 2 2 3.0 0.5 -2.0 2 0.5 3 3 4.0 0.5 0.0 3 0.5 4 4 5.0 1.0 -2.5 Return Direction Status Position Id 0 -0.500 Long Closed 0 1 -1.000 Short Closed 1 2 0.000 Long Closed 2 3 -0.625 Short Closed 3 >>> # Exit trades >>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades >>> exit_trades.readable Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 1 1 0 1.0 1 1 2.0 2 2 0 1.0 2 2 3.0 3 3 0 1.0 3 3 4.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 1 1 2.0 0.5 -0.5 1 0.5 2 2 3.0 0.5 -2.0 2 0.5 3 3 4.0 0.5 0.0 3 0.5 4 4 5.0 1.0 -2.5 Return Direction Status Position Id 0 -0.500 Long Closed 0 1 -1.000 Short Closed 1 2 0.000 Long Closed 2 3 -0.625 Short Closed 3 >>> # Positions >>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions >>> positions.readable Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 1 1 0 1.0 1 1 2.0 2 2 0 1.0 2 2 3.0 3 3 0 1.0 3 3 4.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 1 1 2.0 0.5 -0.5 1 0.5 2 2 3.0 0.5 -2.0 2 0.5 3 3 4.0 0.5 0.0 3 0.5 4 4 5.0 1.0 -2.5 Return Direction Status 0 -0.500 Long Closed 1 -1.000 Short Closed 2 0.000 Long Closed 3 -0.625 Short Closed >>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum() True ``` * Open position: ```pycon >>> # Entry trades >>> pf_kwargs = dict( ... close=pd.Series([1., 2., 3., 4., 5.]), ... size=pd.Series([1., 0., 0., 0., 0.]), ... fixed_fees=1. ... 
) >>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades >>> entry_trades.readable Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 -1 4 5.0 0.0 3.0 Return Direction Status Position Id 0 3.0 Long Open 0 >>> # Exit trades >>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades >>> exit_trades.readable Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 -1 4 5.0 0.0 3.0 Return Direction Status Position Id 0 3.0 Long Open 0 >>> # Positions >>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions >>> positions.readable Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\ 0 0 0 1.0 0 0 1.0 Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\ 0 1.0 -1 4 5.0 0.0 3.0 Return Direction Status 0 3.0 Long Open >>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum() True ``` Get trade count, trade PnL, and winning trade PnL: ```pycon >>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.]) >>> size = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5]) >>> trades = vbt.Portfolio.from_orders(price, size).trades >>> trades.count() 6 >>> trades.pnl.sum() -3.0 >>> trades.winning.count() 2 >>> trades.winning.pnl.sum() 1.5 ``` Get count and PnL of trades with duration of more than 2 days: ```pycon >>> mask = (trades.records['exit_idx'] - trades.records['entry_idx']) > 2 >>> trades_filtered = trades.apply_mask(mask) >>> trades_filtered.count() 2 >>> trades_filtered.pnl.sum() -3.0 ``` ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Trades.metrics`. ```pycon >>> price = vbt.RandomData.pull( ... ['a', 'b'], ... start=datetime(2020, 1, 1), ... end=datetime(2020, 3, 1), ... seed=vbt.symbol_dict(a=42, b=43) ... 
).get() ``` [=100% "100%"]{: .candystripe .candystripe-animate } ```pycon >>> size = pd.DataFrame({ ... 'a': np.random.randint(-1, 2, size=len(price.index)), ... 'b': np.random.randint(-1, 2, size=len(price.index)), ... }, index=price.index, columns=price.columns) >>> pf = vbt.Portfolio.from_orders(price, size, fees=0.01, init_cash="auto") >>> pf.trades['a'].stats() Start 2019-12-31 23:00:00+00:00 End 2020-02-29 23:00:00+00:00 Period 61 days 00:00:00 First Trade Start 2019-12-31 23:00:00+00:00 Last Trade End 2020-02-29 23:00:00+00:00 Coverage 60 days 00:00:00 Overlap Coverage 49 days 00:00:00 Total Records 19.0 Total Long Trades 2.0 Total Short Trades 17.0 Total Closed Trades 18.0 Total Open Trades 1.0 Open Trade PnL 16.063 Win Rate [%] 61.111111 Max Win Streak 11.0 Max Loss Streak 7.0 Best Trade [%] 3.526377 Worst Trade [%] -6.543679 Avg Winning Trade [%] 2.225861 Avg Losing Trade [%] -3.601313 Avg Winning Trade Duration 32 days 19:38:10.909090909 Avg Losing Trade Duration 5 days 00:00:00 Profit Factor 1.022425 Expectancy 0.028157 SQN 0.039174 Name: agg_stats, dtype: object ``` Positions share almost identical metrics with trades: ```pycon >>> pf.positions['a'].stats() Start 2019-12-31 23:00:00+00:00 End 2020-02-29 23:00:00+00:00 Period 61 days 00:00:00 First Trade Start 2019-12-31 23:00:00+00:00 Last Trade End 2020-02-29 23:00:00+00:00 Coverage 60 days 00:00:00 Overlap Coverage 0 days 00:00:00 Total Records 5.0 Total Long Trades 2.0 Total Short Trades 3.0 Total Closed Trades 4.0 Total Open Trades 1.0 Open Trade PnL 38.356823 Win Rate [%] 0.0 Max Win Streak 0.0 Max Loss Streak 4.0 Best Trade [%] -1.529613 Worst Trade [%] -6.543679 Avg Winning Trade [%] NaN Avg Losing Trade [%] -3.786739 Avg Winning Trade Duration NaT Avg Losing Trade Duration 4 days 00:00:00 Profit Factor 0.0 Expectancy -5.446748 SQN -1.794214 Name: agg_stats, dtype: object ``` To also include open trades/positions when calculating metrics such as win rate, pass `incl_open=True`: ```pycon >>> 
pf.trades['a'].stats(settings=dict(incl_open=True)) Start 2019-12-31 23:00:00+00:00 End 2020-02-29 23:00:00+00:00 Period 61 days 00:00:00 First Trade Start 2019-12-31 23:00:00+00:00 Last Trade End 2020-02-29 23:00:00+00:00 Coverage 60 days 00:00:00 Overlap Coverage 49 days 00:00:00 Total Records 19.0 Total Long Trades 2.0 Total Short Trades 17.0 Total Closed Trades 18.0 Total Open Trades 1.0 Open Trade PnL 16.063 Win Rate [%] 61.111111 Max Win Streak 12.0 Max Loss Streak 7.0 Best Trade [%] 3.526377 Worst Trade [%] -6.543679 Avg Winning Trade [%] 2.238896 Avg Losing Trade [%] -3.601313 Avg Winning Trade Duration 33 days 18:00:00 Avg Losing Trade Duration 5 days 00:00:00 Profit Factor 1.733143 Expectancy 0.872096 SQN 0.804714 Name: agg_stats, dtype: object ``` `Trades.stats` also supports (re-)grouping: ```pycon >>> pf.trades.stats(group_by=True) Start 2019-12-31 23:00:00+00:00 End 2020-02-29 23:00:00+00:00 Period 61 days 00:00:00 First Trade Start 2019-12-31 23:00:00+00:00 Last Trade End 2020-02-29 23:00:00+00:00 Coverage 61 days 00:00:00 Overlap Coverage 61 days 00:00:00 Total Records 37 Total Long Trades 5 Total Short Trades 32 Total Closed Trades 35 Total Open Trades 2 Open Trade PnL 1.336259 Win Rate [%] 37.142857 Max Win Streak 11 Max Loss Streak 10 Best Trade [%] 3.526377 Worst Trade [%] -8.710238 Avg Winning Trade [%] 1.907799 Avg Losing Trade [%] -3.259135 Avg Winning Trade Duration 28 days 14:46:09.230769231 Avg Losing Trade Duration 14 days 00:00:00 Profit Factor 0.340493 Expectancy -1.292596 SQN -2.509223 Name: group, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Trades.subplots`. 
`Trades` class has two subplots based on `Trades.plot` and `Trades.plot_pnl`: ```pycon >>> pf.trades['a'].plots().show() ``` ![](/assets/images/api/trades_plots.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/trades_plots.dark.svg#only-dark){: .iimg loading=lazy } """ import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base.indexes import stack_indexes from vectorbtpro.base.reshaping import to_1d_array, to_2d_array, to_pd_array, broadcast_to from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic.enums import range_dt from vectorbtpro.generic.ranges import Ranges from vectorbtpro.portfolio import nb from vectorbtpro.portfolio.enums import TradeDirection, TradeStatus, trade_dt from vectorbtpro.portfolio.orders import Orders from vectorbtpro.records.decorators import attach_fields, override_field_config, attach_shortcut_properties from vectorbtpro.records.mapped_array import MappedArray from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils.array_ import min_rel_rescale, max_rel_rescale from vectorbtpro.utils.colors import adjust_lightness from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig from vectorbtpro.utils.template import Rep, RepEval, RepFunc __all__ = [ "Trades", "EntryTrades", "ExitTrades", "Positions", ] __pdoc__ = {} # ############# Trades ############# # trades_field_config = ReadonlyConfig( dict( dtype=trade_dt, settings={ "id": dict(title="Trade Id"), "idx": dict(name="exit_idx"), # remap field of Records "start_idx": dict(name="entry_idx"), # remap field of Ranges "end_idx": dict(name="exit_idx"), # remap field of Ranges "size": dict(title="Size"), "entry_order_id": dict(title="Entry Order Id", mapping="ids"), "entry_idx": dict(title="Entry Index", mapping="index"), "entry_price": dict(title="Avg Entry Price"), "entry_fees": 
dict(title="Entry Fees"), "exit_order_id": dict(title="Exit Order Id", mapping="ids"), "exit_idx": dict(title="Exit Index", mapping="index"), "exit_price": dict(title="Avg Exit Price"), "exit_fees": dict(title="Exit Fees"), "pnl": dict(title="PnL"), "return": dict(title="Return", hovertemplate="$title: %{customdata[$index]:,%}"), "direction": dict(title="Direction", mapping=TradeDirection), "status": dict(title="Status", mapping=TradeStatus), "parent_id": dict(title="Position Id", mapping="ids"), }, ) ) """_""" __pdoc__[ "trades_field_config" ] = f"""Field config for `Trades`. ```python {trades_field_config.prettify()} ``` """ trades_attach_field_config = ReadonlyConfig( { "return": dict(attach="returns"), "direction": dict(attach_filters=True), "status": dict(attach_filters=True, on_conflict="ignore"), } ) """_""" __pdoc__[ "trades_attach_field_config" ] = f"""Config of fields to be attached to `Trades`. ```python {trades_attach_field_config.prettify()} ``` """ trades_shortcut_config = ReadonlyConfig( dict( ranges=dict(), long_view=dict(), short_view=dict(), winning=dict(), losing=dict(), winning_streak=dict(obj_type="mapped_array"), losing_streak=dict(obj_type="mapped_array"), win_rate=dict(obj_type="red_array"), profit_factor=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)), rel_profit_factor=dict( obj_type="red_array", method_name="get_profit_factor", method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_profit_factor")), ), expectancy=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)), rel_expectancy=dict( obj_type="red_array", method_name="get_expectancy", method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_expectancy")), ), sqn=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)), rel_sqn=dict( obj_type="red_array", method_name="get_sqn", method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_sqn")), ), best_price=dict(obj_type="mapped_array"), 
worst_price=dict(obj_type="mapped_array"), best_price_idx=dict(obj_type="mapped_array"), worst_price_idx=dict(obj_type="mapped_array"), expanding_best_price=dict(obj_type="array"), expanding_worst_price=dict(obj_type="array"), mfe=dict(obj_type="mapped_array"), mfe_returns=dict( obj_type="mapped_array", method_name="get_mfe", method_kwargs=dict(use_returns=True), ), mae=dict(obj_type="mapped_array"), mae_returns=dict( obj_type="mapped_array", method_name="get_mae", method_kwargs=dict(use_returns=True), ), expanding_mfe=dict(obj_type="array"), expanding_mfe_returns=dict( obj_type="array", method_name="get_expanding_mfe", method_kwargs=dict(use_returns=True), ), expanding_mae=dict(obj_type="array"), expanding_mae_returns=dict( obj_type="array", method_name="get_expanding_mae", method_kwargs=dict(use_returns=True), ), edge_ratio=dict(obj_type="red_array"), running_edge_ratio=dict(obj_type="array"), ) ) """_""" __pdoc__[ "trades_shortcut_config" ] = f"""Config of shortcut properties to be attached to `Trades`. 
```python {trades_shortcut_config.prettify()} ``` """ TradesT = tp.TypeVar("TradesT", bound="Trades") @attach_shortcut_properties(trades_shortcut_config) @attach_fields(trades_attach_field_config) @override_field_config(trades_field_config) class Trades(Ranges): """Extends `vectorbtpro.generic.ranges.Ranges` for working with trade-like records, such as entry trades, exit trades, and positions.""" @property def field_config(self) -> Config: return self._field_config def get_ranges(self, **kwargs) -> Ranges: """Get records of type `vectorbtpro.generic.ranges.Ranges`.""" new_records_arr = np.empty(self.values.shape, dtype=range_dt) new_records_arr["id"][:] = self.get_field_arr("id").copy() new_records_arr["col"][:] = self.get_field_arr("col").copy() new_records_arr["start_idx"][:] = self.get_field_arr("entry_idx").copy() new_records_arr["end_idx"][:] = self.get_field_arr("exit_idx").copy() new_records_arr["status"][:] = self.get_field_arr("status").copy() return Ranges.from_records( self.wrapper, new_records_arr, open=self._open, high=self._high, low=self._low, close=self._close, **kwargs, ) # ############# Views ############# # def get_long_view(self: TradesT, **kwargs) -> TradesT: """Get long view.""" filter_mask = self.get_field_arr("direction") == TradeDirection.Long return self.apply_mask(filter_mask, **kwargs) def get_short_view(self: TradesT, **kwargs) -> TradesT: """Get short view.""" filter_mask = self.get_field_arr("direction") == TradeDirection.Short return self.apply_mask(filter_mask, **kwargs) # ############# Stats ############# # def get_winning(self: TradesT, **kwargs) -> TradesT: """Get winning trades.""" filter_mask = self.get_field_arr("pnl") > 0.0 return self.apply_mask(filter_mask, **kwargs) def get_losing(self: TradesT, **kwargs) -> TradesT: """Get losing trades.""" filter_mask = self.get_field_arr("pnl") < 0.0 return self.apply_mask(filter_mask, **kwargs) def get_winning_streak(self, **kwargs) -> MappedArray: """Get winning streak at each trade 
in the current column. See `vectorbtpro.portfolio.nb.records.trade_winning_streak_nb`.""" return self.apply(nb.trade_winning_streak_nb, dtype=int_, **kwargs) def get_losing_streak(self, **kwargs) -> MappedArray: """Get losing streak at each trade in the current column. See `vectorbtpro.portfolio.nb.records.trade_losing_streak_nb`.""" return self.apply(nb.trade_losing_streak_nb, dtype=int_, **kwargs) def get_win_rate( self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeSeries: """Get rate of winning trades.""" wrap_kwargs = merge_dicts(dict(name_or_index="win_rate"), wrap_kwargs) return self.get_map_field("pnl").reduce( nb.win_rate_reduce_nb, group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) def get_profit_factor( self, use_returns: bool = False, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeSeries: """Get profit factor.""" wrap_kwargs = merge_dicts(dict(name_or_index="profit_factor"), wrap_kwargs) if use_returns: mapped_arr = self.get_map_field("return") else: mapped_arr = self.get_map_field("pnl") return mapped_arr.reduce( nb.profit_factor_reduce_nb, group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) def get_expectancy( self, use_returns: bool = False, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeSeries: """Get average profitability.""" wrap_kwargs = merge_dicts(dict(name_or_index="expectancy"), wrap_kwargs) if use_returns: mapped_arr = self.get_map_field("return") else: mapped_arr = self.get_map_field("pnl") return mapped_arr.reduce( nb.expectancy_reduce_nb, group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) def get_sqn( 
self, ddof: int = 1, use_returns: bool = False, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeSeries: """Get System Quality Number (SQN).""" wrap_kwargs = merge_dicts(dict(name_or_index="sqn"), wrap_kwargs) if use_returns: mapped_arr = self.get_map_field("return") else: mapped_arr = self.get_map_field("pnl") return mapped_arr.reduce( nb.sqn_reduce_nb, ddof, group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) def get_best_price( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, **kwargs, ) -> MappedArray: """Get best price. See `vectorbtpro.portfolio.nb.records.best_price_nb`.""" return self.apply( nb.best_price_nb, self._open, self._high, self._low, self._close, entry_price_open, exit_price_close, max_duration, **kwargs, ) def get_worst_price( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, **kwargs, ) -> MappedArray: """Get worst price. See `vectorbtpro.portfolio.nb.records.worst_price_nb`.""" return self.apply( nb.worst_price_nb, self._open, self._high, self._low, self._close, entry_price_open, exit_price_close, max_duration, **kwargs, ) def get_best_price_idx( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, relative: bool = True, **kwargs, ) -> MappedArray: """Get (relative) index of best price. 
See `vectorbtpro.portfolio.nb.records.best_price_idx_nb`.""" return self.apply( nb.best_price_idx_nb, self._open, self._high, self._low, self._close, entry_price_open, exit_price_close, max_duration, relative, dtype=int_, **kwargs, ) def get_worst_price_idx( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, relative: bool = True, **kwargs, ) -> MappedArray: """Get (relative) index of worst price. See `vectorbtpro.portfolio.nb.records.worst_price_idx_nb`.""" return self.apply( nb.worst_price_idx_nb, self._open, self._high, self._low, self._close, entry_price_open, exit_price_close, max_duration, relative, dtype=int_, **kwargs, ) def get_expanding_best_price( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, jitted: tp.JittedOption = None, clean_index_kwargs: tp.KwargsLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get expanding best price. See `vectorbtpro.portfolio.nb.records.expanding_best_price_nb`.""" func = jit_reg.resolve_option(nb.expanding_best_price_nb, jitted) out = func( self.values, self._open, self._high, self._low, self._close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, ) if clean_index_kwargs is None: clean_index_kwargs = {} new_columns = stack_indexes( ( self.wrapper.columns[self.get_field_arr("col")], pd.Index(self.get_field_arr("id"), name="id"), ), **clean_index_kwargs, ) if wrap_kwargs is None: wrap_kwargs = {} return self.wrapper.wrap( out, group_by=False, index=pd.RangeIndex(stop=len(out)), columns=new_columns, **wrap_kwargs ) def get_expanding_worst_price( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, jitted: tp.JittedOption = None, clean_index_kwargs: tp.KwargsLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get expanding worst price. 
See `vectorbtpro.portfolio.nb.records.expanding_worst_price_nb`.""" func = jit_reg.resolve_option(nb.expanding_worst_price_nb, jitted) out = func( self.values, self._open, self._high, self._low, self._close, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, ) if clean_index_kwargs is None: clean_index_kwargs = {} new_columns = stack_indexes( ( self.wrapper.columns[self.get_field_arr("col")], pd.Index(self.get_field_arr("id"), name="id"), ), **clean_index_kwargs, ) if wrap_kwargs is None: wrap_kwargs = {} return self.wrapper.wrap( out, group_by=False, index=pd.RangeIndex(stop=len(out)), columns=new_columns, **wrap_kwargs ) def get_mfe( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, use_returns: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> MappedArray: """Get MFE. See `vectorbtpro.portfolio.nb.records.mfe_nb`.""" best_price = self.resolve_shortcut_attr( "best_price", entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, jitted=jitted, chunked=chunked, ) func = jit_reg.resolve_option(nb.mfe_nb, jitted) func = ch_reg.resolve_option(func, chunked) mfe = func( self.get_field_arr("size"), self.get_field_arr("direction"), self.get_field_arr("entry_price"), best_price.values, use_returns=use_returns, ) return self.map_array(mfe, **kwargs) def get_mae( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, use_returns: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> MappedArray: """Get MAE. 
See `vectorbtpro.portfolio.nb.records.mae_nb`.""" worst_price = self.resolve_shortcut_attr( "worst_price", entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, jitted=jitted, chunked=chunked, ) func = jit_reg.resolve_option(nb.mae_nb, jitted) func = ch_reg.resolve_option(func, chunked) mae = func( self.get_field_arr("size"), self.get_field_arr("direction"), self.get_field_arr("entry_price"), worst_price.values, use_returns=use_returns, ) return self.map_array(mae, **kwargs) def get_expanding_mfe( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, use_returns: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Get expanding MFE. See `vectorbtpro.portfolio.nb.records.expanding_mfe_nb`.""" expanding_best_price = self.resolve_shortcut_attr( "expanding_best_price", entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, jitted=jitted, **kwargs, ) func = jit_reg.resolve_option(nb.expanding_mfe_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.values, expanding_best_price.values, use_returns=use_returns, ) return ArrayWrapper.from_obj(expanding_best_price).wrap(out) def get_expanding_mae( self, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, use_returns: bool = False, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Get expanding MAE. 
See `vectorbtpro.portfolio.nb.records.expanding_mae_nb`.""" expanding_worst_price = self.resolve_shortcut_attr( "expanding_worst_price", entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, jitted=jitted, **kwargs, ) func = jit_reg.resolve_option(nb.expanding_mae_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.values, expanding_worst_price.values, use_returns=use_returns, ) return ArrayWrapper.from_obj(expanding_worst_price).wrap(out) def get_edge_ratio( self, volatility: tp.Optional[tp.ArrayLike] = None, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get edge ratio. See `vectorbtpro.portfolio.nb.records.edge_ratio_nb`. If `volatility` is None, calculates the 14-period ATR if both high and low are provided, otherwise the 14-period rolling standard deviation.""" if self._close is None: raise ValueError("Must provide close") if volatility is None: if self._high is not None and self._low is not None: from vectorbtpro.indicators.nb import atr_nb from vectorbtpro.generic.enums import WType if self._high is None or self._low is None: raise ValueError("Must provide high and low for ATR calculation") volatility = atr_nb( high=to_2d_array(self._high), low=to_2d_array(self._low), close=to_2d_array(self._close), window=14, wtype=WType.Wilder, )[1] else: from vectorbtpro.indicators.nb import msd_nb from vectorbtpro.generic.enums import WType volatility = msd_nb( close=to_2d_array(self._close), window=14, wtype=WType.Wilder, ) else: volatility = broadcast_to(volatility, self.wrapper, to_pd=False, keep_flex=True) col_map = self.col_mapper.get_col_map(group_by=group_by) func = jit_reg.resolve_option(nb.edge_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.values, col_map, 
self._open, self._high, self._low, self._close, volatility, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, ) if wrap_kwargs is None: wrap_kwargs = {} return self.wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs) def get_running_edge_ratio( self, volatility: tp.Optional[tp.ArrayLike] = None, entry_price_open: bool = False, exit_price_close: bool = False, max_duration: tp.Optional[int] = None, incl_shorter: bool = False, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Get running edge ratio. See `vectorbtpro.portfolio.nb.records.running_edge_ratio_nb`. If `volatility` is None, calculates the 14-period ATR if both high and low are provided, otherwise the 14-period rolling standard deviation.""" if self._close is None: raise ValueError("Must provide close") if volatility is None: if self._high is not None and self._low is not None: from vectorbtpro.indicators.nb import atr_nb from vectorbtpro.generic.enums import WType if self._high is None or self._low is None: raise ValueError("Must provide high and low for ATR calculation") volatility = atr_nb( high=to_2d_array(self._high), low=to_2d_array(self._low), close=to_2d_array(self._close), window=14, wtype=WType.Wilder, )[1] else: from vectorbtpro.indicators.nb import msd_nb from vectorbtpro.generic.enums import WType volatility = msd_nb( close=to_2d_array(self._close), window=14, wtype=WType.Wilder, ) else: volatility = broadcast_to(volatility, self.wrapper, to_pd=False, keep_flex=True) col_map = self.col_mapper.get_col_map(group_by=group_by) func = jit_reg.resolve_option(nb.running_edge_ratio_nb, jitted) out = func( self.values, col_map, self._open, self._high, self._low, self._close, volatility, entry_price_open=entry_price_open, exit_price_close=exit_price_close, max_duration=max_duration, incl_shorter=incl_shorter, ) if wrap_kwargs is None: wrap_kwargs = {} return 
self.wrapper.wrap(out, group_by=group_by, index=pd.RangeIndex(stop=len(out)), **wrap_kwargs) @property def stats_defaults(self) -> tp.Kwargs: """Defaults for `Trades.stats`. Merges `vectorbtpro.generic.ranges.Ranges.stats_defaults` and `stats` from `vectorbtpro._settings.trades`.""" from vectorbtpro._settings import settings trades_stats_cfg = settings["trades"]["stats"] return merge_dicts(Ranges.stats_defaults.__get__(self), trades_stats_cfg) _metrics: tp.ClassVar[Config] = HybridConfig( dict( start_index=dict( title="Start Index", calc_func=lambda self: self.wrapper.index[0], agg_func=None, tags="wrapper", ), end_index=dict( title="End Index", calc_func=lambda self: self.wrapper.index[-1], agg_func=None, tags="wrapper", ), total_duration=dict( title="Total Duration", calc_func=lambda self: len(self.wrapper.index), apply_to_timedelta=True, agg_func=None, tags="wrapper", ), first_trade_start=dict( title="First Trade Start", calc_func="entry_idx.nth", n=0, wrap_kwargs=dict(to_index=True), tags=["trades", "index"], ), last_trade_end=dict( title="Last Trade End", calc_func="exit_idx.nth", n=-1, wrap_kwargs=dict(to_index=True), tags=["trades", "index"], ), coverage=dict( title="Coverage", calc_func="coverage", overlapping=False, normalize=False, apply_to_timedelta=True, tags=["ranges", "coverage"], ), overlap_coverage=dict( title="Overlap Coverage", calc_func="coverage", overlapping=True, normalize=False, apply_to_timedelta=True, tags=["ranges", "coverage"], ), total_records=dict(title="Total Records", calc_func="count", tags="records"), total_long_trades=dict( title="Total Long Trades", calc_func="direction_long.count", tags=["trades", "long"] ), total_short_trades=dict( title="Total Short Trades", calc_func="direction_short.count", tags=["trades", "short"] ), total_closed_trades=dict( title="Total Closed Trades", calc_func="status_closed.count", tags=["trades", "closed"] ), total_open_trades=dict(title="Total Open Trades", calc_func="status_open.count", 
tags=["trades", "open"]), open_trade_pnl=dict(title="Open Trade PnL", calc_func="status_open.pnl.sum", tags=["trades", "open"]), win_rate=dict( title="Win Rate [%]", calc_func="status_closed.get_win_rate", post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['trades', *incl_open_tags]"), ), winning_streak=dict( title="Max Win Streak", calc_func=RepEval("'winning_streak.max' if incl_open else 'status_closed.winning_streak.max'"), wrap_kwargs=dict(dtype=pd.Int64Dtype()), tags=RepEval("['trades', *incl_open_tags, 'streak']"), ), losing_streak=dict( title="Max Loss Streak", calc_func=RepEval("'losing_streak.max' if incl_open else 'status_closed.losing_streak.max'"), wrap_kwargs=dict(dtype=pd.Int64Dtype()), tags=RepEval("['trades', *incl_open_tags, 'streak']"), ), best_trade=dict( title="Best Trade [%]", calc_func=RepEval("'returns.max' if incl_open else 'status_closed.returns.max'"), post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['trades', *incl_open_tags]"), ), worst_trade=dict( title="Worst Trade [%]", calc_func=RepEval("'returns.min' if incl_open else 'status_closed.returns.min'"), post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['trades', *incl_open_tags]"), ), avg_winning_trade=dict( title="Avg Winning Trade [%]", calc_func=RepEval("'winning.returns.mean' if incl_open else 'status_closed.winning.returns.mean'"), post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['trades', *incl_open_tags, 'winning']"), ), avg_losing_trade=dict( title="Avg Losing Trade [%]", calc_func=RepEval("'losing.returns.mean' if incl_open else 'status_closed.losing.returns.mean'"), post_calc_func=lambda self, out, settings: out * 100, tags=RepEval("['trades', *incl_open_tags, 'losing']"), ), avg_winning_trade_duration=dict( title="Avg Winning Trade Duration", calc_func=RepEval("'winning.avg_duration' if incl_open else 'status_closed.winning.get_avg_duration'"), fill_wrap_kwargs=True, tags=RepEval("['trades', 
def plot_pnl(
    self,
    column: tp.Optional[tp.Label] = None,
    group_by: tp.GroupByLike = False,
    pct_scale: bool = False,
    marker_size_range: tp.Tuple[float, float] = (7, 14),
    opacity_range: tp.Tuple[float, float] = (0.75, 0.9),
    closed_trace_kwargs: tp.KwargsLike = None,
    closed_profit_trace_kwargs: tp.KwargsLike = None,
    closed_loss_trace_kwargs: tp.KwargsLike = None,
    open_trace_kwargs: tp.KwargsLike = None,
    hline_shape_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    xref: str = "x",
    yref: str = "y",
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot trade PnL or returns.

    Each trade becomes one marker placed at its exit index. Marker size and opacity
    are rescaled from the absolute trade return. Trades whose PnL is exactly zero are
    drawn in the "Closed" trace and are removed from the "Open" trace. Trades with
    a NaN return are left out of the "Closed", "Closed - Profit" and "Closed - Loss" traces.

    Args:
        column (str): Name of the column to plot.
        group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
        pct_scale (bool): Whether to set y-axis to `Trades.returns`, otherwise to `Trades.pnl`.
        marker_size_range (tuple): Range of marker size.
        opacity_range (tuple): Range of marker opacity.
        closed_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed" markers.
        closed_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
            for "Closed - Profit" markers.
        closed_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
            for "Closed - Loss" markers.
        open_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Open" markers.
        hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for zeroline.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        xref (str): X coordinate axis.
        yref (str): Y coordinate axis.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Usage:
        ```pycon
        >>> index = pd.date_range("2020", periods=7)
        >>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.], index=index)
        >>> orders = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5], index=index)
        >>> pf = vbt.Portfolio.from_orders(price, orders)
        >>> pf.trades.plot_pnl().show()
        ```

        ![](/assets/images/api/trades_plot_pnl.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/trades_plot_pnl.dark.svg#only-dark){: .iimg loading=lazy }
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    import plotly.graph_objects as go
    from vectorbtpro.utils.figure import make_figure, get_domain
    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]

    self_col = self.select_col(column=column, group_by=group_by)

    # Replace omitted keyword dictionaries with fresh empty ones
    closed_trace_kwargs = {} if closed_trace_kwargs is None else closed_trace_kwargs
    closed_profit_trace_kwargs = {} if closed_profit_trace_kwargs is None else closed_profit_trace_kwargs
    closed_loss_trace_kwargs = {} if closed_loss_trace_kwargs is None else closed_loss_trace_kwargs
    open_trace_kwargs = {} if open_trace_kwargs is None else open_trace_kwargs
    hline_shape_kwargs = {} if hline_shape_kwargs is None else hline_shape_kwargs
    add_trace_kwargs = {} if add_trace_kwargs is None else add_trace_kwargs
    marker_size_range = tuple(marker_size_range)

    # "x2" -> "xaxis2", "y" -> "yaxis"
    xaxis = "xaxis" + xref[1:]
    yaxis = "yaxis" + yref[1:]

    if fig is None:
        fig = make_figure()
    if pct_scale:
        yaxis_defaults = {"tickformat": ".2%", "title": "Return"}
    else:
        yaxis_defaults = {"title": "PnL"}
    fig.update_layout(**{xaxis: {}, yaxis: yaxis_defaults})
    # User-supplied layout is applied second so that it wins over the defaults
    fig.update_layout(**layout_kwargs)
    x_domain = get_domain(xref, fig)
    # Only x_domain is used below; the y lookup is kept so the call sequence is unchanged
    y_domain = get_domain(yref, fig)

    if self_col.count() > 0:
        # Pull the record fields needed for plotting
        exit_idx = self_col.get_map_field_to_index("exit_idx")
        pnl = self_col.get_field_arr("pnl")
        returns = self_col.get_field_arr("return")
        status = self_col.get_field_arr("status")

        has_return = ~np.isnan(returns)
        neutral_mask = (pnl == 0) & has_return
        profit_mask = (pnl > 0) & has_return
        loss_mask = (pnl < 0) & has_return

        abs_returns = np.abs(returns)
        marker_size = min_rel_rescale(abs_returns, marker_size_range)
        opacity = max_rel_rescale(np.abs(returns), opacity_range)

        open_mask = status == TradeStatus.Open
        closed_profit_mask = (~open_mask) & profit_mask
        closed_loss_mask = (~open_mask) & loss_mask
        # Zero-PnL trades are already covered by the neutral trace
        open_mask &= ~neutral_mask

        y_values = returns if pct_scale else pnl
        color_schema = plotting_cfg["contrast_color_schema"]
        # (mask, trace name, color key, user overrides), drawn in this order
        trace_specs = (
            (neutral_mask, "Closed", "gray", closed_trace_kwargs),
            (closed_profit_mask, "Closed - Profit", "green", closed_profit_trace_kwargs),
            (closed_loss_mask, "Closed - Loss", "red", closed_loss_trace_kwargs),
            (open_mask, "Open", "orange", open_trace_kwargs),
        )
        for mask, name, color_key, trace_kwargs in trace_specs:
            color = color_schema[color_key]
            if not np.any(mask):
                # Nothing to draw for this category
                continue
            if self_col.get_field_setting("parent_id", "ignore", False):
                hover_fields = ["id", "exit_idx", "pnl", "return"]
            else:
                hover_fields = ["id", "parent_id", "exit_idx", "pnl", "return"]
            customdata, hovertemplate = self_col.prepare_customdata(incl_fields=hover_fields, mask=mask)
            marker = dict(
                symbol="circle",
                color=color,
                size=marker_size[mask],
                opacity=opacity[mask],
                line=dict(width=1, color=adjust_lightness(color)),
            )
            _kwargs = merge_dicts(
                dict(
                    x=exit_idx[mask],
                    y=y_values[mask],
                    mode="markers",
                    marker=marker,
                    name=name,
                    customdata=customdata,
                    hovertemplate=hovertemplate,
                ),
                trace_kwargs,
            )
            fig.add_trace(go.Scatter(**_kwargs), **add_trace_kwargs)

    # Dashed horizontal line at zero spanning the x-axis domain
    zeroline = dict(
        type="line",
        xref="paper",
        yref=yref,
        x0=x_domain[0],
        y0=0,
        x1=x_domain[1],
        y1=0,
        line=dict(
            color="gray",
            dash="dash",
        ),
    )
    fig.add_shape(**merge_dicts(zeroline, hline_shape_kwargs))
    return fig


def plot_returns(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_pnl` for `Trades.returns`.

    Forwards all arguments and forces `pct_scale=True`."""
    fig = self.plot_pnl(*args, pct_scale=True, **kwargs)
    return fig
closed_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed" markers. closed_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Profit" markers. closed_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Loss" markers. open_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Open" markers. hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for horizontal zeroline. vline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for vertical zeroline. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. xref (str): X coordinate axis. yref (str): Y coordinate axis. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. Usage: ```pycon >>> index = pd.date_range("2020", periods=10) >>> price = pd.Series([1., 2., 3., 4., 5., 6., 5., 3., 2., 1.], index=index) >>> orders = pd.Series([1., -0.5, 0., -0.5, 2., 0., -0.5, -0.5, 0., -0.5], index=index) >>> pf = vbt.Portfolio.from_orders(price, orders) >>> trades = pf.trades >>> trades.plot_against_pnl("MFE").show() ``` ![](/assets/images/api/trades_plot_against_pnl.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/trades_plot_against_pnl.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") import plotly.graph_objects as go from vectorbtpro.utils.figure import make_figure, get_domain from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=group_by) if closed_trace_kwargs is None: closed_trace_kwargs = {} if closed_profit_trace_kwargs is None: closed_profit_trace_kwargs = {} if closed_loss_trace_kwargs is None: closed_loss_trace_kwargs = {} if open_trace_kwargs is 
None: open_trace_kwargs = {} if hline_shape_kwargs is None: hline_shape_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} xaxis = "xaxis" + xref[1:] yaxis = "yaxis" + yref[1:] if isinstance(field, str): if field_label is None: field_label = field field = getattr(self_col, field.lower()) if isinstance(field, MappedArray): field = field.values if field_label is None: field_label = "Field" if fig is None: fig = make_figure() def_layout_kwargs = {xaxis: {}, yaxis: {}} if pct_scale: def_layout_kwargs[xaxis]["tickformat"] = ".2%" def_layout_kwargs[xaxis]["title"] = "Return" else: def_layout_kwargs[xaxis]["title"] = "PnL" if field_pct_scale: def_layout_kwargs[yaxis]["tickformat"] = ".2%" def_layout_kwargs[yaxis]["title"] = field_label fig.update_layout(**def_layout_kwargs) fig.update_layout(**layout_kwargs) x_domain = get_domain(xref, fig) y_domain = get_domain(yref, fig) if self_col.count() > 0: # Extract information pnl = self_col.get_field_arr("pnl") returns = self_col.get_field_arr("return") status = self_col.get_field_arr("status") valid_mask = ~np.isnan(returns) neutral_mask = (pnl == 0) & valid_mask profit_mask = (pnl > 0) & valid_mask loss_mask = (pnl < 0) & valid_mask open_mask = status == TradeStatus.Open closed_profit_mask = (~open_mask) & profit_mask closed_loss_mask = (~open_mask) & loss_mask open_mask &= ~neutral_mask def _plot_scatter(mask, name, color, kwargs): if np.any(mask): if self_col.get_field_setting("parent_id", "ignore", False): customdata, hovertemplate = self_col.prepare_customdata( incl_fields=["id", "exit_idx", "pnl", "return"], mask=mask ) else: customdata, hovertemplate = self_col.prepare_customdata( incl_fields=["id", "parent_id", "exit_idx", "pnl", "return"], mask=mask ) _kwargs = merge_dicts( dict( x=returns[mask] if pct_scale else pnl[mask], y=field[mask], mode="markers", marker=dict( symbol="circle", color=color, size=7, line=dict(width=1, color=adjust_lightness(color)), ), name=name, customdata=customdata, 
def plot_mfe(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_against_pnl` for `Trades.mfe`.

    Positional arguments are forwarded after `field` and `field_label`, so the
    first one is taken as `column`."""
    # Passing the field positionally keeps extra positional arguments from colliding with it
    return self.plot_against_pnl("mfe", "MFE", *args, **kwargs)


def plot_mfe_returns(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_against_pnl` for `Trades.mfe_returns`.

    Both axes use a percentage scale. Positional arguments are forwarded after
    `field` and `field_label`, so the first one is taken as `column`."""
    return self.plot_against_pnl(
        "mfe_returns",
        "MFE Return",
        *args,
        pct_scale=True,
        field_pct_scale=True,
        **kwargs,
    )


def plot_mae(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_against_pnl` for `Trades.mae`.

    Positional arguments are forwarded after `field` and `field_label`, so the
    first one is taken as `column`."""
    return self.plot_against_pnl("mae", "MAE", *args, **kwargs)


def plot_mae_returns(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_against_pnl` for `Trades.mae_returns`.

    Both axes use a percentage scale. Positional arguments are forwarded after
    `field` and `field_label`, so the first one is taken as `column`."""
    return self.plot_against_pnl(
        "mae_returns",
        "MAE Return",
        *args,
        pct_scale=True,
        field_pct_scale=True,
        **kwargs,
    )


def plot_expanding(
    self,
    field: tp.Union[str, tp.Array1d, MappedArray],
    field_label: tp.Optional[str] = None,
    column: tp.Optional[tp.Label] = None,
    group_by: tp.GroupByLike = False,
    plot_bands: bool = False,
    colorize: tp.Union[bool, str, tp.Callable] = "last",
    field_pct_scale: bool = False,
    add_trace_kwargs: tp.KwargsLike = None,
    fig: tp.Optional[tp.BaseFigure] = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot projections of an expanding field.

    Args:
        field (str or array_like): Field to be plotted.

            Can be also provided as a 2-dim array. A string is lower-cased and
            prefixed with "expanding_" if it does not already start with it.
        field_label (str): Label of the field.

            Defaults to the passed field name, or to "Field" if the field is not a string.
        column (str): Name of the column to plot. Optional.

            If None, no column is selected and the whole object is used.
        group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
        plot_bands (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
        colorize (bool, str or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
        field_pct_scale (bool): Whether to make y-axis a percentage scale.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.

    Usage:
        ```pycon
        >>> index = pd.date_range("2020", periods=10)
        >>> price = pd.Series([1., 2., 3., 2., 4., 5., 6., 5., 6., 7.], index=index)
        >>> orders = pd.Series([1., 0., 0., -2., 0., 0., 2., 0., 0., -1.], index=index)
        >>> pf = vbt.Portfolio.from_orders(price, orders)
        >>> pf.trades.plot_expanding("MFE").show()
        ```

        ![](/assets/images/api/trades_plot_expanding.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/trades_plot_expanding.dark.svg#only-dark){: .iimg loading=lazy }
    """
    if column is not None:
        self_col = self.select_col(column=column, group_by=group_by)
    else:
        self_col = self

    if isinstance(field, str):
        if field_label is None:
            field_label = field
        if not field.lower().startswith("expanding_"):
            field = "expanding_" + field
        field = getattr(self_col, field.lower())
    if isinstance(field, MappedArray):
        field = field.values
    if field_label is None:
        field_label = "Field"
    field = to_pd_array(field)

    fig = field.vbt.plot_projections(
        add_trace_kwargs=add_trace_kwargs,
        fig=fig,
        plot_bands=plot_bands,
        colorize=colorize,
        **kwargs,
    )

    # Use the y-axis of the last trace; fall back to the primary axis when that
    # trace has no axis set or when the figure holds no traces at all
    yref = "y"
    if len(fig.data) > 0 and fig.data[-1]["yaxis"] is not None:
        yref = fig.data[-1]["yaxis"]
    yaxis = "yaxis" + yref[1:]

    # Only fill in axis settings that the caller did not pass explicitly
    if "title" not in kwargs.get(yaxis, {}):
        fig.update_layout(**{yaxis: dict(title=field_label)})
    if field_pct_scale and "tickformat" not in kwargs.get(yaxis, {}):
        fig.update_layout(**{yaxis: dict(tickformat=".2%")})
    return fig


def plot_expanding_mfe(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_expanding` for `Trades.expanding_mfe`.

    Positional arguments are forwarded after `field` and `field_label`, so the
    first one is taken as `column`."""
    return self.plot_expanding("expanding_mfe", "MFE", *args, **kwargs)


def plot_expanding_mfe_returns(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_expanding` for `Trades.expanding_mfe_returns`.

    The y-axis uses a percentage scale. Positional arguments are forwarded after
    `field` and `field_label`, so the first one is taken as `column`."""
    return self.plot_expanding(
        "expanding_mfe_returns",
        "MFE Return",
        *args,
        field_pct_scale=True,
        **kwargs,
    )
def plot_expanding_mae_returns(self, *args, **kwargs) -> tp.BaseFigure:
    """`Trades.plot_expanding` for `Trades.expanding_mae_returns`.

    The y-axis uses a percentage scale. Positional arguments are forwarded after
    `field` and `field_label`, so the first one is taken as `column`."""
    # Passing the field positionally keeps extra positional arguments from colliding with it
    return self.plot_expanding(
        "expanding_mae_returns",
        "MAE Return",
        *args,
        field_pct_scale=True,
        **kwargs,
    )


def plot_running_edge_ratio(
    self,
    column: tp.Optional[tp.Label] = None,
    volatility: tp.Optional[tp.ArrayLike] = None,
    entry_price_open: bool = False,
    exit_price_close: bool = False,
    max_duration: tp.Optional[int] = None,
    incl_shorter: bool = False,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    xref: str = "x",
    yref: str = "y",
    hline_shape_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.BaseFigure:
    """Plot one column/group of edge ratio.

    The running edge ratio is resolved through the `running_edge_ratio` shortcut
    attribute, plotted against the constant 1, and a dashed horizontal line is
    added at 1.

    `hline_shape_kwargs` are merged over the defaults of that line.

    `**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
    from vectorbtpro.utils.figure import get_domain

    if hline_shape_kwargs is None:
        hline_shape_kwargs = {}

    running_edge_ratio = self.resolve_shortcut_attr(
        "running_edge_ratio",
        volatility=volatility,
        entry_price_open=entry_price_open,
        exit_price_close=exit_price_close,
        max_duration=max_duration,
        incl_shorter=incl_shorter,
        group_by=group_by,
        jitted=jitted,
    )
    running_edge_ratio = self.select_col_from_obj(
        running_edge_ratio,
        column,
        wrapper=self.wrapper.regroup(group_by),
    )
    kwargs = merge_dicts(
        dict(
            trace_kwargs=dict(name="Edge Ratio"),
            other_trace_kwargs="hidden",
        ),
        kwargs,
    )
    fig = running_edge_ratio.vbt.plot_against(1, **kwargs)
    x_domain = get_domain(xref, fig)
    fig.add_shape(
        **merge_dicts(
            dict(
                type="line",
                line=dict(
                    color="gray",
                    dash="dash",
                ),
                xref="paper",
                yref=yref,
                x0=x_domain[0],
                y0=1.0,
                x1=x_domain[1],
                y1=1.0,
            ),
            hline_shape_kwargs,
        )
    )
    return fig


def plot(
    self,
    column: tp.Optional[tp.Label] = None,
    plot_ohlc: bool = True,
    plot_close: bool = True,
    plot_markers: bool = True,
    plot_zones: bool = True,
    plot_by_type: bool = True,
    ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
    ohlc_trace_kwargs: tp.KwargsLike = None,
    close_trace_kwargs: tp.KwargsLike = None,
    entry_trace_kwargs: tp.KwargsLike = None,
    exit_trace_kwargs: tp.KwargsLike = None,
    exit_profit_trace_kwargs: tp.KwargsLike = None,
    exit_loss_trace_kwargs: tp.KwargsLike = None,
    active_trace_kwargs: tp.KwargsLike = None,
    profit_shape_kwargs: tp.KwargsLike = None,
    loss_shape_kwargs: tp.KwargsLike = None,
    add_trace_kwargs: tp.KwargsLike = None,
    xref: str = "x",
    yref: str = "y",
    fig: tp.Optional[tp.BaseFigure] = None,
    **layout_kwargs,
) -> tp.BaseFigure:
    """Plot trades.

    Args:
        column (str): Name of the column to plot.
        plot_ohlc (bool): Whether to plot OHLC.

            OHLC is plotted only if open, high, low, and close are all available;
            otherwise close is plotted if `plot_close` is True and close is available.
        plot_close (bool): Whether to plot close.
        plot_markers (bool): Whether to plot markers.
        plot_zones (bool): Whether to plot zones.
        plot_by_type (bool): Whether to plot exit trades by type.

            Otherwise, the appearance will be controlled using `exit_trace_kwargs`.
        ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

            Pass None to use the default.
        ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.

            An opacity of 0.5 is used unless `opacity` is provided.
            The passed dictionary is not modified.
        close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Trades.close`.
        entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Entry" markers.
        exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Exit" markers.
        exit_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
            for "Exit - Profit" markers.
        exit_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
            for "Exit - Loss" markers.
        active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Active" markers.
        profit_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape`
            for profit zones.
        loss_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape`
            for loss zones.
        add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
        xref (str): X coordinate axis.
        yref (str): Y coordinate axis.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.

    Usage:
        ```pycon
        >>> index = pd.date_range("2020", periods=7)
        >>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.], index=index)
        >>> size = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5], index=index)
        >>> pf = vbt.Portfolio.from_orders(price, size)
        >>> pf.trades.plot().show()
        ```

        ![](/assets/images/api/trades_plot.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/trades_plot.dark.svg#only-dark){: .iimg loading=lazy }
    """
    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("plotly")
    import plotly.graph_objects as go
    from vectorbtpro.utils.figure import make_figure
    from vectorbtpro._settings import settings

    plotting_cfg = settings["plotting"]
    contrast_colors = plotting_cfg["contrast_color_schema"]

    self_col = self.select_col(column=column, group_by=False)

    if ohlc_trace_kwargs is None:
        ohlc_trace_kwargs = {}
    if close_trace_kwargs is None:
        close_trace_kwargs = {}
    close_trace_kwargs = merge_dicts(
        dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
        close_trace_kwargs,
    )
    if entry_trace_kwargs is None:
        entry_trace_kwargs = {}
    if exit_trace_kwargs is None:
        exit_trace_kwargs = {}
    if exit_profit_trace_kwargs is None:
        exit_profit_trace_kwargs = {}
    if exit_loss_trace_kwargs is None:
        exit_loss_trace_kwargs = {}
    if active_trace_kwargs is None:
        active_trace_kwargs = {}
    if profit_shape_kwargs is None:
        profit_shape_kwargs = {}
    if loss_shape_kwargs is None:
        loss_shape_kwargs = {}
    if add_trace_kwargs is None:
        add_trace_kwargs = {}

    if fig is None:
        fig = make_figure()
    fig.update_layout(**layout_kwargs)

    # Plot price: OHLC when all four series exist, otherwise fall back to close
    has_ohlc = (
        self_col._open is not None
        and self_col._high is not None
        and self_col._low is not None
        and self_col._close is not None
    )
    if plot_ohlc and has_ohlc:
        ohlc_df = pd.DataFrame(
            {
                "open": self_col.open,
                "high": self_col.high,
                "low": self_col.low,
                "close": self_col.close,
            }
        )
        # Build a new dict instead of writing the default into the caller's dict
        ohlc_trace_kwargs = {"opacity": 0.5, **ohlc_trace_kwargs}
        fig = ohlc_df.vbt.ohlcv.plot(
            ohlc_type=ohlc_type,
            plot_volume=False,
            ohlc_trace_kwargs=ohlc_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )
    elif plot_close and self_col._close is not None:
        fig = self_col.close.vbt.lineplot(
            trace_kwargs=close_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
        )

    if self_col.count() > 0:
        # Extract information
        entry_idx = self_col.get_map_field_to_index("entry_idx", minus_one_to_zero=True)
        entry_price = self_col.get_field_arr("entry_price")
        exit_idx = self_col.get_map_field_to_index("exit_idx")
        exit_price = self_col.get_field_arr("exit_price")
        pnl = self_col.get_field_arr("pnl")
        status = self_col.get_field_arr("status")
        # Durations as strings, appended to the hover info of end markers
        duration = to_1d_array(
            self_col.wrapper.arr_to_timedelta(
                self_col.duration.values,
                to_pd=True,
                silence_warnings=True,
            ).astype(str)
        )

        if plot_markers:
            # `parent_id` is left out of the hover data when the field is marked as ignored
            if self_col.get_field_setting("parent_id", "ignore", False):
                parent_fields = []
            else:
                parent_fields = ["parent_id"]

            # Plot Entry markers
            entry_customdata, entry_hovertemplate = self_col.prepare_customdata(
                incl_fields=[
                    "id",
                    "entry_order_id",
                    *parent_fields,
                    "entry_idx",
                    "size",
                    "entry_price",
                    "entry_fees",
                    "direction",
                ]
            )
            entry_color = contrast_colors["blue"]
            _entry_trace_kwargs = merge_dicts(
                dict(
                    x=entry_idx,
                    y=entry_price,
                    mode="markers",
                    marker=dict(
                        symbol="square",
                        color=entry_color,
                        size=7,
                        line=dict(width=1, color=adjust_lightness(entry_color)),
                    ),
                    name="Entry",
                    customdata=entry_customdata,
                    hovertemplate=entry_hovertemplate,
                ),
                entry_trace_kwargs,
            )
            entry_scatter = go.Scatter(**_entry_trace_kwargs)
            fig.add_trace(entry_scatter, **add_trace_kwargs)

            # Plot end markers
            def _plot_end_markers(mask, name, color, kwargs, incl_status=False) -> None:
                """Add one scatter trace with the end markers selected by `mask`, if any."""
                if not np.any(mask):
                    return
                exit_customdata, exit_hovertemplate = self_col.prepare_customdata(
                    incl_fields=[
                        "id",
                        "exit_order_id",
                        *parent_fields,
                        "exit_idx",
                        "size",
                        "exit_price",
                        "exit_fees",
                        "pnl",
                        "return",
                        "direction",
                        *(("status",) if incl_status else ()),
                    ],
                    append_info=[(duration, "Duration")],
                    mask=mask,
                )
                _kwargs = merge_dicts(
                    dict(
                        x=exit_idx[mask],
                        y=exit_price[mask],
                        mode="markers",
                        marker=dict(
                            symbol="square",
                            color=color,
                            size=7,
                            line=dict(width=1, color=adjust_lightness(color)),
                        ),
                        name=name,
                        customdata=exit_customdata,
                        hovertemplate=exit_hovertemplate,
                    ),
                    kwargs,
                )
                scatter = go.Scatter(**_kwargs)
                fig.add_trace(scatter, **add_trace_kwargs)

            if plot_by_type:
                closed_mask = status == TradeStatus.Closed
                # Plot Exit markers
                _plot_end_markers(
                    closed_mask & (pnl == 0.0),
                    "Exit",
                    contrast_colors["gray"],
                    exit_trace_kwargs,
                )
                # Plot Exit - Profit markers
                _plot_end_markers(
                    closed_mask & (pnl > 0.0),
                    "Exit - Profit",
                    contrast_colors["green"],
                    exit_profit_trace_kwargs,
                )
                # Plot Exit - Loss markers
                _plot_end_markers(
                    closed_mask & (pnl < 0.0),
                    "Exit - Loss",
                    contrast_colors["red"],
                    exit_loss_trace_kwargs,
                )
                # Plot Active markers
                _plot_end_markers(
                    status == TradeStatus.Open,
                    "Active",
                    contrast_colors["orange"],
                    active_trace_kwargs,
                )
            else:
                # Plot all end markers as one "Exit" trace and show the status on hover
                _plot_end_markers(
                    np.full(len(status), True),
                    "Exit",
                    contrast_colors["pink"],
                    exit_trace_kwargs,
                    incl_status=True,
                )

        if plot_zones:
            # Plot profit zones first, then loss zones
            zone_specs = (
                (self_col.winning, contrast_colors["green"], profit_shape_kwargs),
                (self_col.losing, contrast_colors["red"], loss_shape_kwargs),
            )
            for zone_trades, fillcolor, zone_shape_kwargs in zone_specs:
                zone_trades.plot_shapes(
                    plot_ohlc=False,
                    plot_close=False,
                    shape_kwargs=merge_dicts(
                        dict(
                            yref=Rep("yref"),
                            y0=RepFunc(lambda record: record["entry_price"]),
                            y1=RepFunc(lambda record: record["exit_price"]),
                            fillcolor=fillcolor,
                        ),
                        zone_shape_kwargs,
                    ),
                    add_trace_kwargs=add_trace_kwargs,
                    xref=xref,
                    yref=yref,
                    fig=fig,
                )

    return fig
EntryTradesT = tp.TypeVar("EntryTradesT", bound="EntryTrades")


@override_field_config(entry_trades_field_config)
class EntryTrades(Trades):
    """Extends `Trades` for working with entry trade records."""

    @property
    def field_config(self) -> Config:
        return self._field_config

    @classmethod
    def from_orders(
        cls: tp.Type[EntryTradesT],
        orders: Orders,
        open: tp.Optional[tp.ArrayLike] = None,
        high: tp.Optional[tp.ArrayLike] = None,
        low: tp.Optional[tp.ArrayLike] = None,
        close: tp.Optional[tp.ArrayLike] = None,
        init_position: tp.ArrayLike = 0.0,
        init_price: tp.ArrayLike = np.nan,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> EntryTradesT:
        """Build `EntryTrades` from `vectorbtpro.portfolio.orders.Orders`.

        Any of `open`, `high`, `low`, and `close` that is None is taken from `orders`.
        Records are generated by `nb.get_entry_trades_nb`, resolved with the `jitted`
        and `chunked` options.

        `**kwargs` are passed to `EntryTrades.from_records`."""
        # Fall back to the price data attached to the orders
        if open is None:
            open = orders._open
        if high is None:
            high = orders._high
        if low is None:
            low = orders._low
        if close is None:
            close = orders._close

        func = jit_reg.resolve_option(nb.get_entry_trades_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        trade_records_arr = func(
            orders.values,
            to_2d_array(orders.wrapper.wrap(close, group_by=False)),
            orders.col_mapper.col_map,
            init_position=to_1d_array(init_position),
            init_price=to_1d_array(init_price),
            sim_start=None if sim_start is None else to_1d_array(sim_start),
            sim_end=None if sim_end is None else to_1d_array(sim_end),
        )
        return cls.from_records(
            orders.wrapper,
            trade_records_arr,
            open=open,
            high=high,
            low=low,
            close=close,
            **kwargs,
        )

    def plot_signals(
        self,
        column: tp.Optional[tp.Label] = None,
        plot_ohlc: bool = True,
        plot_close: bool = True,
        ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
        ohlc_trace_kwargs: tp.KwargsLike = None,
        close_trace_kwargs: tp.KwargsLike = None,
        long_entry_trace_kwargs: tp.KwargsLike = None,
        short_entry_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot entry trade signals.

        Args:
            column (str): Name of the column to plot.
            plot_ohlc (bool): Whether to plot OHLC.

                OHLC is plotted only if open, high, low, and close are all available;
                otherwise close is plotted if `plot_close` is True and close is available.
            plot_close (bool): Whether to plot close.
            ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.

                Pass None to use the default.
            ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.

                An opacity of 0.5 is used unless `opacity` is provided.
                The passed dictionary is not modified.
            close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for `EntryTrades.close`.
            long_entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for "Long Entry" markers.
            short_entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`
                for "Short Entry" markers.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Usage:
            ```pycon
            >>> index = pd.date_range("2020", periods=7)
            >>> price = pd.Series([1, 2, 3, 2, 3, 4, 3], index=index)
            >>> orders = pd.Series([1, 0, -1, 0, -1, 2, -2], index=index)
            >>> pf = vbt.Portfolio.from_orders(price, orders)
            >>> pf.entry_trades.plot_signals().show()
            ```

            ![](/assets/images/api/entry_trades_plot_signals.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/entry_trades_plot_signals.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("plotly")
        import plotly.graph_objects as go
        from vectorbtpro.utils.figure import make_figure
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        self_col = self.select_col(column=column, group_by=False)

        if ohlc_trace_kwargs is None:
            ohlc_trace_kwargs = {}
        if close_trace_kwargs is None:
            close_trace_kwargs = {}
        close_trace_kwargs = merge_dicts(
            dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
            close_trace_kwargs,
        )
        if long_entry_trace_kwargs is None:
            long_entry_trace_kwargs = {}
        if short_entry_trace_kwargs is None:
            short_entry_trace_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        if fig is None:
            fig = make_figure()
        fig.update_layout(**layout_kwargs)

        # Plot price: OHLC when all four series exist, otherwise fall back to close
        has_ohlc = (
            self_col._open is not None
            and self_col._high is not None
            and self_col._low is not None
            and self_col._close is not None
        )
        if plot_ohlc and has_ohlc:
            ohlc_df = pd.DataFrame(
                {
                    "open": self_col.open,
                    "high": self_col.high,
                    "low": self_col.low,
                    "close": self_col.close,
                }
            )
            # Build a new dict instead of writing the default into the caller's dict
            ohlc_trace_kwargs = {"opacity": 0.5, **ohlc_trace_kwargs}
            fig = ohlc_df.vbt.ohlcv.plot(
                ohlc_type=ohlc_type,
                plot_volume=False,
                ohlc_trace_kwargs=ohlc_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )
        elif plot_close and self_col._close is not None:
            fig = self_col.close.vbt.lineplot(
                trace_kwargs=close_trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                fig=fig,
            )

        if self_col.count() > 0:
            # Extract information
            entry_idx = self_col.get_map_field_to_index("entry_idx", minus_one_to_zero=True)
            entry_price = self_col.get_field_arr("entry_price")
            direction = self_col.get_field_arr("direction")

            # Skip `parent_id` when its field setting marks it as ignored,
            # mirroring the guard used by `Trades.plot`
            if self_col.get_field_setting("parent_id", "ignore", False):
                parent_fields = []
            else:
                parent_fields = ["parent_id"]

            def _plot_entry_markers(mask, name, color, kwargs):
                """Add one scatter trace with the entry markers selected by `mask`, if any."""
                if not np.any(mask):
                    return
                entry_customdata, entry_hovertemplate = self_col.prepare_customdata(
                    incl_fields=[
                        "id",
                        "entry_order_id",
                        *parent_fields,
                        "entry_idx",
                        "size",
                        "entry_price",
                        "entry_fees",
                        "pnl",
                        "return",
                        "status",
                    ],
                    mask=mask,
                )
                _kwargs = merge_dicts(
                    dict(
                        x=entry_idx[mask],
                        y=entry_price[mask],
                        mode="markers",
                        # Transparent fill: only the colored outline is visible
                        marker=dict(
                            symbol="circle",
                            color="rgba(0, 0, 0, 0)",
                            size=15,
                            line=dict(
                                color=color,
                                width=2,
                            ),
                        ),
                        name=name,
                        customdata=entry_customdata,
                        hovertemplate=entry_hovertemplate,
                    ),
                    kwargs,
                )
                scatter = go.Scatter(**_kwargs)
                fig.add_trace(scatter, **add_trace_kwargs)

            # Plot Long Entry markers
            _plot_entry_markers(
                direction == TradeDirection.Long,
                "Long Entry",
                plotting_cfg["contrast_color_schema"]["green"],
                long_entry_trace_kwargs,
            )
            # Plot Short Entry markers
            _plot_entry_markers(
                direction == TradeDirection.Short,
                "Short Entry",
                plotting_cfg["contrast_color_schema"]["red"],
                short_entry_trace_kwargs,
            )

        return fig
bool = True, plot_close: bool = True, ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None, ohlc_trace_kwargs: tp.KwargsLike = None, close_trace_kwargs: tp.KwargsLike = None, long_exit_trace_kwargs: tp.KwargsLike = None, short_exit_trace_kwargs: tp.KwargsLike = None, add_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **layout_kwargs, ) -> tp.BaseFigure: """Plot exit trade signals. Args: column (str): Name of the column to plot. plot_ohlc (bool): Whether to plot OHLC. plot_close (bool): Whether to plot close. ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace. Pass None to use the default. ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`. close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ExitTrades.close`. long_exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Long Exit" markers. short_exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Short Exit" markers. add_trace_kwargs (dict): Keyword arguments passed to `add_trace`. fig (Figure or FigureWidget): Figure to add traces to. **layout_kwargs: Keyword arguments for layout. 
Usage: ```pycon >>> index = pd.date_range("2020", periods=7) >>> price = pd.Series([1, 2, 3, 2, 3, 4, 3], index=index) >>> orders = pd.Series([1, 0, -1, 0, -1, 2, -2], index=index) >>> pf = vbt.Portfolio.from_orders(price, orders) >>> pf.exit_trades.plot_signals().show() ``` ![](/assets/images/api/exit_trades_plot_signals.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/exit_trades_plot_signals.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro.utils.module_ import assert_can_import assert_can_import("plotly") import plotly.graph_objects as go from vectorbtpro.utils.figure import make_figure from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] self_col = self.select_col(column=column, group_by=False) if ohlc_trace_kwargs is None: ohlc_trace_kwargs = {} if close_trace_kwargs is None: close_trace_kwargs = {} close_trace_kwargs = merge_dicts( dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"), close_trace_kwargs, ) if long_exit_trace_kwargs is None: long_exit_trace_kwargs = {} if short_exit_trace_kwargs is None: short_exit_trace_kwargs = {} if add_trace_kwargs is None: add_trace_kwargs = {} if fig is None: fig = make_figure() fig.update_layout(**layout_kwargs) # Plot close if ( plot_ohlc and self_col._open is not None and self_col._high is not None and self_col._low is not None and self_col._close is not None ): ohlc_df = pd.DataFrame( { "open": self_col.open, "high": self_col.high, "low": self_col.low, "close": self_col.close, } ) if "opacity" not in ohlc_trace_kwargs: ohlc_trace_kwargs["opacity"] = 0.5 fig = ohlc_df.vbt.ohlcv.plot( ohlc_type=ohlc_type, plot_volume=False, ohlc_trace_kwargs=ohlc_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) elif plot_close and self_col._close is not None: fig = self_col.close.vbt.lineplot( trace_kwargs=close_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, ) if self_col.count() > 0: # Extract information exit_idx = 
self_col.get_map_field_to_index("exit_idx", minus_one_to_zero=True) exit_price = self_col.get_field_arr("exit_price") direction = self_col.get_field_arr("direction") def _plot_exit_markers(mask, name, color, kwargs): if np.any(mask): exit_customdata, exit_hovertemplate = self_col.prepare_customdata( incl_fields=[ "id", "exit_order_id", "parent_id", "exit_idx", "size", "exit_price", "exit_fees", "pnl", "return", "status", ], mask=mask, ) _kwargs = merge_dicts( dict( x=exit_idx[mask], y=exit_price[mask], mode="markers", marker=dict( symbol="circle", color=color, size=8, ), name=name, customdata=exit_customdata, hovertemplate=exit_hovertemplate, ), kwargs, ) scatter = go.Scatter(**_kwargs) fig.add_trace(scatter, **add_trace_kwargs) # Plot Long Exit markers _plot_exit_markers( direction == TradeDirection.Long, "Long Exit", plotting_cfg["contrast_color_schema"]["green"], long_exit_trace_kwargs, ) # Plot Short Exit markers _plot_exit_markers( direction == TradeDirection.Short, "Short Exit", plotting_cfg["contrast_color_schema"]["red"], short_exit_trace_kwargs, ) return fig ExitTrades.override_field_config_doc(__pdoc__) # ############# Positions ############# # positions_field_config = ReadonlyConfig( dict(settings={"id": dict(title="Position Id"), "parent_id": dict(title="Parent Id", ignore=True)}), ) """_""" __pdoc__[ "positions_field_config" ] = f"""Field config for `Positions`. 
@override_field_config(positions_field_config)
class Positions(Trades):
    """Extends `Trades` for working with position records."""

    @property
    def field_config(self) -> Config:
        return self._field_config

    @classmethod
    def from_trades(
        cls: tp.Type[PositionsT],
        trades: Trades,
        open: tp.Optional[tp.ArrayLike] = None,
        high: tp.Optional[tp.ArrayLike] = None,
        low: tp.Optional[tp.ArrayLike] = None,
        close: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        **kwargs,
    ) -> PositionsT:
        """Build `Positions` from `Trades`.

        Any price array that is not provided is taken from `trades`. The position
        records are computed from the trade records and the column map of `trades`
        by `nb.get_positions_nb`, resolved first against `jitted` and then against `chunked`.
        Remaining keyword arguments are forwarded to `Positions.from_records`."""
        # Price arrays fall back to those stored in the trades
        price_kwargs = dict(
            open=trades._open if open is None else open,
            high=trades._high if high is None else high,
            low=trades._low if low is None else low,
            close=trades._close if close is None else close,
        )

        positions_func = ch_reg.resolve_option(
            jit_reg.resolve_option(nb.get_positions_nb, jitted),
            chunked,
        )
        position_records_arr = positions_func(trades.values, trades.col_mapper.col_map)
        return cls.from_records(trades.wrapper, position_records_arr, **price_kwargs, **kwargs)
@register_vbt_accessor("px")
@attach_px_methods
class PXAccessor(BaseAccessor):
    """Accessor for running Plotly Express functions.

    Accessible via `pd.Series.vbt.px` and `pd.DataFrame.vbt.px`.

    The plotting methods themselves are not defined in the class body: they are
    attached to the class by the `attach_px_methods` decorator.

    Usage:
        ```pycon
        >>> from vectorbtpro import *

        >>> pd.Series([1, 2, 3]).vbt.px.bar().show()
        ```

        ![](/assets/images/api/px_bar.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/px_bar.dark.svg#only-dark){: .iimg loading=lazy }
    """

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        **kwargs,
    ) -> None:
        # All arguments are forwarded unchanged to the base accessor
        BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)


@register_sr_vbt_accessor("px")
class PXSRAccessor(PXAccessor, BaseSRAccessor):
    """Accessor for running Plotly Express functions. For Series only.

    Accessible via `pd.Series.vbt.px`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # The Series base is always initialized with `_full_init=False`;
        # presumably this keeps the shared base from being initialized twice - TODO confirm in BaseSRAccessor
        BaseSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)

        # The generic accessor is initialized only when a full initialization was requested
        if _full_init:
            PXAccessor.__init__(self, wrapper, obj=obj, **kwargs)


@register_df_vbt_accessor("px")
class PXDFAccessor(PXAccessor, BaseDFAccessor):
    """Accessor for running Plotly Express functions. For DataFrames only.

    Accessible via `pd.DataFrame.vbt.px`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # The DataFrame base is always initialized with `_full_init=False`;
        # presumably this keeps the shared base from being initialized twice - TODO confirm in BaseDFAccessor
        BaseDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)

        # The generic accessor is initialized only when a full initialization was requested
        if _full_init:
            PXAccessor.__init__(self, wrapper, obj=obj, **kwargs)
def attach_px_methods(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
    """Class decorator to attach Plotly Express methods.

    Every function of `plotly.express` that accepts the argument `data_frame`, as well as
    `imshow`, becomes a method of the same name on `cls`. The method passes the accessor's
    object as the first argument to the Plotly Express function and wraps the result with
    `vectorbtpro.utils.figure.make_figure`.

    The keyword arguments `template`, `width`, and `height` default to the values in
    `settings["plotting"]["layout"]`; they are passed to the Plotly Express function and
    also serve as the base of `layout`."""
    for px_func_name, px_func in getmembers(px, isfunction):
        if not checks.func_accepts_arg(px_func, "data_frame") and px_func_name != "imshow":
            continue

        # The current function is bound through defaults to avoid late binding in the closure
        def plot_method(
            self,
            *args,
            _px_func_name: str = px_func_name,
            _px_func: tp.Callable = px_func,
            layout: tp.KwargsLike = None,
            **kwargs,
        ) -> tp.BaseFigure:
            from vectorbtpro._settings import settings

            layout_cfg = settings["plotting"]["layout"]

            layout_kwargs = {key: kwargs.pop(key, layout_cfg[key]) for key in ("template", "width", "height")}
            layout = merge_dicts(layout_kwargs, layout)

            # Fix category_orders: sort the categories of a color column given by name
            color = kwargs.get("color")
            if isinstance(color, str) and isinstance(self.obj, pd.DataFrame) and color in self.obj.columns:
                category_orders = {color: sorted(self.obj[color].unique())}
                kwargs = merge_dicts(dict(category_orders=category_orders), kwargs)

            # Fix labels on a shallow copy: stringify the Series name or clean the column labels
            obj = self.obj.copy(deep=False)
            if not isinstance(obj, pd.Series):
                obj.columns = clean_labels(obj.columns)
            elif obj.name is not None:
                obj = obj.rename(str(obj.name))
            obj.index = clean_labels(obj.index)

            # imshow receives a two-dimensional array instead of a pandas object
            data = to_2d_array(obj) if _px_func_name == "imshow" else obj
            return make_figure(_px_func(data, *args, **layout_kwargs, **kwargs), layout=layout)

        plot_method.__name__ = px_func_name
        plot_method.__module__ = cls.__module__
        plot_method.__qualname__ = f"{cls.__name__}.{plot_method.__name__}"
        plot_method.__doc__ = f"""Plot using `{px_func.__module__ + '.' + px_func.__name__}`."""
        setattr(cls, px_func_name, plot_method)
    return cls
vectorbt works with two different representations of data: matrices and records. A matrix, in this context, is just an array of one-dimensional arrays, each corresponding to a separate feature. The matrix itself holds only one kind of information (one attribute). For example, one can create a matrix for entry signals, with columns being different strategy configurations. But what if the matrix is huge and sparse? What if there is more information we would like to represent by each element? Creating multiple matrices would be a waste of memory. Records make possible representing complex, sparse information in a dense format. They are just an array of one-dimensional arrays of a fixed schema, where each element holds a different kind of information. You can imagine records being a DataFrame, where each row represents a record and each column represents a specific attribute. Read more on structured arrays [here](https://numpy.org/doc/stable/user/basics.rec.html). For example, let's represent two DataFrames as a single record array: ```plaintext a b 0 1.0 5.0 attr1 = 1 2.0 NaN 2 NaN 7.0 3 4.0 8.0 a b 0 9.0 13.0 attr2 = 1 10.0 NaN 2 NaN 15.0 3 12.0 16.0 | v id col idx attr1 attr2 0 0 0 0 1 9 1 1 0 1 2 10 2 2 0 3 4 12 3 0 1 0 5 13 4 1 1 2 7 15 5 2 1 3 8 16 ``` Another advantage of records is that they are not constrained by size. Multiple records can map to a single element in a matrix. For example, one can define multiple orders at the same timestamp, which is impossible to represent in a matrix form without duplicating index entries or using complex data types. Consider the following example: ```pycon >>> from vectorbtpro import * >>> example_dt = np.dtype([ ... ('id', int_), ... ('col', int_), ... ('idx', int_), ... ('some_field', float_) ... ]) >>> records_arr = np.array([ ... (0, 0, 0, 10.), ... (1, 0, 1, 11.), ... (2, 0, 2, 12.), ... (0, 1, 0, 13.), ... (1, 1, 1, 14.), ... (2, 1, 2, 15.), ... (0, 2, 0, 16.), ... (1, 2, 1, 17.), ... (2, 2, 2, 18.) ... 
], dtype=example_dt) >>> wrapper = vbt.ArrayWrapper(index=['x', 'y', 'z'], ... columns=['a', 'b', 'c'], ndim=2, freq='1 day') >>> records = vbt.Records(wrapper, records_arr) ``` ## Printing There are two ways to print records: * Raw dataframe that preserves field names and data types: ```pycon >>> records.records id col idx some_field 0 0 0 0 10.0 1 1 0 1 11.0 2 2 0 2 12.0 3 0 1 0 13.0 4 1 1 1 14.0 5 2 1 2 15.0 6 0 2 0 16.0 7 1 2 1 17.0 8 2 2 2 18.0 ``` * Readable dataframe that takes into consideration `Records.field_config`: ```pycon >>> records.readable Id Column Timestamp some_field 0 0 a x 10.0 1 1 a y 11.0 2 2 a z 12.0 3 0 b x 13.0 4 1 b y 14.0 5 2 b z 15.0 6 0 c x 16.0 7 1 c y 17.0 8 2 c z 18.0 ``` ## Mapping `Records` are just [structured arrays](https://numpy.org/doc/stable/user/basics.rec.html) with a bunch of methods and properties for processing them. Their main feature is to map the records array and to reduce it by column (similar to the MapReduce paradigm). The main advantage is that it all happens without conversion to the matrix form and wasting memory resources. `Records` can be mapped to `vectorbtpro.records.mapped_array.MappedArray` in several ways: * Use `Records.map_field` to map a record field: ```pycon >>> records.map_field('some_field') >>> records.map_field('some_field').values array([10., 11., 12., 13., 14., 15., 16., 17., 18.]) ``` * Use `Records.map` to map records using a custom function. ```pycon >>> @njit ... def power_map_nb(record, pow): ... return record.some_field ** pow >>> records.map(power_map_nb, 2) >>> records.map(power_map_nb, 2).values array([100., 121., 144., 169., 196., 225., 256., 289., 324.]) >>> # Map using a meta function >>> @njit ... def power_map_meta_nb(ridx, records, pow): ... 
return records[ridx].some_field ** pow >>> vbt.Records.map(power_map_meta_nb, records.values, 2, col_mapper=records.col_mapper).values array([100., 121., 144., 169., 196., 225., 256., 289., 324.]) ``` * Use `Records.map_array` to convert an array to `vectorbtpro.records.mapped_array.MappedArray`. ```pycon >>> records.map_array(records_arr['some_field'] ** 2) >>> records.map_array(records_arr['some_field'] ** 2).values array([100., 121., 144., 169., 196., 225., 256., 289., 324.]) ``` * Use `Records.apply` to apply a function on each column/group: ```pycon >>> @njit ... def cumsum_apply_nb(records): ... return np.cumsum(records.some_field) >>> records.apply(cumsum_apply_nb) >>> records.apply(cumsum_apply_nb).values array([10., 21., 33., 13., 27., 42., 16., 33., 51.]) >>> group_by = np.array(['first', 'first', 'second']) >>> records.apply(cumsum_apply_nb, group_by=group_by, apply_per_group=True).values array([10., 21., 33., 46., 60., 75., 16., 33., 51.]) >>> # Apply using a meta function >>> @njit ... def cumsum_apply_meta_nb(idxs, col, records): ... return np.cumsum(records[idxs].some_field) >>> vbt.Records.apply(cumsum_apply_meta_nb, records.values, col_mapper=records.col_mapper).values array([10., 21., 33., 13., 27., 42., 16., 33., 51.]) ``` Notice how cumsum resets at each column in the first example and at each group in the second example. ## Filtering Use `Records.apply_mask` to filter elements per column/group: ```pycon >>> mask = [True, False, True, False, True, False, True, False, True] >>> filtered_records = records.apply_mask(mask) >>> filtered_records.records id col idx some_field 0 0 0 0 10.0 1 2 0 2 12.0 2 1 1 1 14.0 3 0 2 0 16.0 4 2 2 2 18.0 ``` ## Grouping One of the key features of `Records` is that you can perform reducing operations on a group of columns as if they were a single column. Groups can be specified by `group_by`, which can be anything from positions or names of column levels, to a NumPy array with actual groups. 
There are multiple ways to define grouping: * When creating `Records`, pass `group_by` to `vectorbtpro.base.wrapping.ArrayWrapper`: ```pycon >>> group_by = np.array(['first', 'first', 'second']) >>> grouped_wrapper = wrapper.replace(group_by=group_by) >>> grouped_records = vbt.Records(grouped_wrapper, records_arr) >>> grouped_records.map_field('some_field').mean() first 12.5 second 17.0 dtype: float64 ``` * Regroup an existing `Records`: ```pycon >>> records.regroup(group_by).map_field('some_field').mean() first 12.5 second 17.0 dtype: float64 ``` * Pass `group_by` directly to the mapping method: ```pycon >>> records.map_field('some_field', group_by=group_by).mean() first 12.5 second 17.0 dtype: float64 ``` * Pass `group_by` directly to the reducing method: ```pycon >>> records.map_field('some_field').mean(group_by=group_by) a 11.0 b 14.0 c 17.0 dtype: float64 ``` !!! note Grouping applies only to reducing operations; there is no change to the arrays. ## Indexing Like any other class subclassing `vectorbtpro.base.wrapping.Wrapping`, we can do pandas indexing on a `Records` instance, which forwards indexing operation to each object with columns: ```pycon >>> records['a'].records id col idx some_field 0 0 0 0 10.0 1 1 0 1 11.0 2 2 0 2 12.0 >>> grouped_records['first'].records id col idx some_field 0 0 0 0 10.0 1 1 0 1 11.0 2 2 0 2 12.0 3 0 1 0 13.0 4 1 1 1 14.0 5 2 1 2 15.0 ``` !!! note Changing index (time axis) is not supported. The object should be treated as a Series rather than a DataFrame; for example, use `some_field.iloc[0]` instead of `some_field.iloc[:, 0]` to get the first column. Indexing behavior depends solely upon `vectorbtpro.base.wrapping.ArrayWrapper`. For example, if `group_select` is enabled indexing will be performed on groups when grouped, otherwise on single columns. ## Caching `Records` supports caching.
If a method or a property requires heavy computation, it's wrapped with `vectorbtpro.utils.decorators.cached_method` and `vectorbtpro.utils.decorators.cached_property` respectively. Caching can be disabled globally via `vectorbtpro._settings.caching`. !!! note Because of caching, class is meant to be immutable and all properties are read-only. To change any attribute, use the `Records.replace` method and pass changes as keyword arguments. ## Saving and loading Like any other class subclassing `vectorbtpro.utils.pickling.Pickleable`, we can save a `Records` instance to the disk with `Records.save` and load it with `Records.load`. ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Records.metrics`. ```pycon >>> records.stats(column='a') Start x End z Period 3 days 00:00:00 Total Records 3 Name: a, dtype: object ``` `Records.stats` also supports (re-)grouping: ```pycon >>> grouped_records.stats(column='first') Start x End z Period 3 days 00:00:00 Total Records 6 Name: first, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Records.subplots`. This class is too generic to have any subplots, but feel free to add custom subplots to your subclass. ## Extending `Records` class can be extended by subclassing. In case some of our fields have the same meaning but different naming (such as the base field `idx`) or other properties, we can override `field_config` using `vectorbtpro.records.decorators.override_field_config`. It will look for configs of all base classes and merge our config on top of them. This preserves any base class property that is not explicitly listed in our config. ```pycon >>> from vectorbtpro.records.decorators import override_field_config >>> my_dt = np.dtype([ ... ('my_id', int_), ... ('my_col', int_), ... ('my_idx', int_) ... ]) >>> my_fields_config = dict( ... dtype=my_dt, ... settings=dict( ... id=dict(name='my_id'), ... col=dict(name='my_col'), ... 
idx=dict(name='my_idx') ... ) ... ) >>> @override_field_config(my_fields_config) ... class MyRecords(vbt.Records): ... pass >>> records_arr = np.array([ ... (0, 0, 0), ... (1, 0, 1), ... (0, 1, 0), ... (1, 1, 1) ... ], dtype=my_dt) >>> wrapper = vbt.ArrayWrapper(index=['x', 'y'], ... columns=['a', 'b'], ndim=2, freq='1 day') >>> my_records = MyRecords(wrapper, records_arr) >>> my_records.id_arr array([0, 1, 0, 1]) >>> my_records.col_arr array([0, 0, 1, 1]) >>> my_records.idx_arr array([0, 1, 0, 1]) ``` Alternatively, we can override the `_field_config` class attribute. ```pycon >>> @override_field_config ... class MyRecords(vbt.Records): ... _field_config = dict( ... dtype=my_dt, ... settings=dict( ... id=dict(name='my_id'), ... idx=dict(name='my_idx'), ... col=dict(name='my_col') ... ) ... ) ``` !!! note Don't forget to decorate the class with `@override_field_config` to inherit configs from base classes. You can stop inheritance by not decorating or passing `merge_configs=False` to the decorator. 
class MetaRecords(type(Analyzable)):
    """Metaclass for `Records`.

    Derives from the metaclass of `Analyzable` and adds the property `field_config`,
    which makes the class attribute `_field_config` readable directly on the class
    (as `cls.field_config`) in addition to the instance property of the same name."""

    @property
    def field_config(cls) -> Config:
        """Field config."""
        # Class-level access: returns the class attribute, not a per-instance copy
        return cls._field_config
@classmethod
def row_stack_records_arrs(cls, *objs: tp.MaybeTuple[tp.RecordArray], **kwargs) -> tp.RecordArray:
    """Stack multiple record arrays along rows.

    Each object in `objs` must provide `values`, `records_arr`, `wrapper`, and `col_mapper`,
    and can be passed either as a separate argument or as a single sequence. The stacked
    wrapper must be passed as the keyword argument `wrapper`.

    Records are emitted column by column, and within each column object by object.
    Fields are adjusted based on their `mapping` in the field config:

    * "index": shifted by the number of rows in the preceding objects
    * "ids": shifted such that ids continue from the preceding objects of the same column
    * "columns": overwritten by the target column if an object with a single column
        is broadcast to multiple columns

    Returns an empty array with the data type of the first object if there are no records."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    n_cols = kwargs["wrapper"].shape_2d[1]
    settings = cls.field_config.get("settings", {})

    # Everything that depends on the object alone is computed once per object
    # rather than once per (column, object) pair
    prepared = []
    for obj in objs:
        col_idxs, col_lens = obj.col_mapper.col_map
        col_end_idxs = np.cumsum(col_lens)
        col_start_idxs = col_end_idxs - col_lens
        field_mappings = [
            (field, settings.get(field, {}).get("mapping", None)) for field in obj.values.dtype.names
        ]
        prepared.append((obj, col_idxs, col_lens, col_start_idxs, col_end_idxs, field_mappings))

    records_arrs = []
    for col in range(n_cols):
        n_rows_sum = 0
        from_id = defaultdict(int)
        for obj, col_idxs, col_lens, col_start_idxs, col_end_idxs, field_mappings in prepared:
            if len(col_idxs) > 0:
                col_records = None
                set_columns = False
                if col > 0 and obj.wrapper.shape_2d[1] == 1:
                    # Object with a single column gets repeated under each target column
                    col_records = obj.records_arr[col_idxs]
                    set_columns = True
                elif col_lens[col] > 0:
                    col_records = obj.records_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]
                if col_records is not None:
                    # Never write into the records of the source object
                    col_records = col_records.copy()
                    for field, field_mapping in field_mappings:
                        if not isinstance(field_mapping, str):
                            continue
                        if field_mapping == "columns" and set_columns:
                            col_records[field][:] = col
                        elif field_mapping == "index":
                            col_records[field][:] += n_rows_sum
                        elif field_mapping == "ids":
                            col_records[field][:] += from_id[field]
                            from_id[field] = col_records[field].max() + 1
                    records_arrs.append(col_records)
            # Rows are counted even if the object has no records in this column
            n_rows_sum += obj.wrapper.shape_2d[0]

    if len(records_arrs) == 0:
        return np.array([], dtype=objs[0].values.dtype)
    return np.concatenate(records_arrs)
@classmethod
def column_stack_records_arrs(
    cls,
    *objs: tp.MaybeTuple[tp.RecordArray],
    get_indexer_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.RecordArray:
    """Stack multiple record arrays along columns.

    Each object in `objs` must provide `values`, `records_arr`, `wrapper`, and `col_mapper`,
    and can be passed either as a separate argument or as a single sequence.

    Records are emitted object by object, and within each object column by column.
    Fields are adjusted based on their `mapping` in the field config:

    * "columns": shifted by the number of columns in the preceding objects
    * "index": translated to positions in the index of the stacked wrapper, which must be
        passed as the keyword argument `wrapper`, if the index of the object differs from it

    `get_indexer_kwargs` are passed to `pandas.Index.get_indexer` when translating positions.

    Returns an empty array with the data type of the first object if there are no records."""
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    if get_indexer_kwargs is None:
        get_indexer_kwargs = {}
    settings = cls.field_config.get("settings", {})

    records_arrs = []
    col_sum = 0
    for obj in objs:
        col_idxs, col_lens = obj.col_mapper.col_map
        if len(col_idxs) > 0:
            col_end_idxs = np.cumsum(col_lens)
            col_start_idxs = col_end_idxs - col_lens
            field_mappings = [
                (field, settings.get(field, {}).get("mapping", None)) for field in obj.values.dtype.names
            ]
            # Whether the index of the object equals the stacked index is the same for all
            # columns and fields; resolved on first use and at most once per object
            same_index = None
            for col in range(len(col_lens)):
                if col_lens[col] > 0:
                    col_records = obj.records_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]
                    # Never write into the records of the source object
                    col_records = col_records.copy()
                    for field, field_mapping in field_mappings:
                        if not isinstance(field_mapping, str):
                            continue
                        if field_mapping == "columns":
                            col_records[field][:] += col_sum
                        elif field_mapping == "index":
                            if same_index is None:
                                same_index = obj.wrapper.index.equals(kwargs["wrapper"].index)
                            if not same_index:
                                # Positions in the old index -> labels -> positions in the new index
                                col_records[field][:] = kwargs["wrapper"].index.get_indexer(
                                    obj.wrapper.index[col_records[field]],
                                    **get_indexer_kwargs,
                                )
                    records_arrs.append(col_records)
        # Columns are counted even if the object has no records
        col_sum += obj.wrapper.shape_2d[1]

    if len(records_arrs) == 0:
        return np.array([], dtype=objs[0].values.dtype)
    return np.concatenate(records_arrs)
def __init__(
    self,
    wrapper: ArrayWrapper,
    records_arr: tp.RecordArray,
    col_mapper: tp.Optional[ColumnMapper] = None,
    **kwargs,
) -> None:
    """Validate the record array against `Records.field_config` and store it.

    Raises `TypeError` if a field of the configured data type is neither a field of
    `records_arr` nor a name known to the field settings. If `col_mapper` is None,
    a `ColumnMapper` is built from the column field of `records_arr`."""
    # Check fields
    records_arr = np.asarray(records_arr)
    checks.assert_not_none(records_arr.dtype.fields)
    settings = self.field_config.get("settings", {})
    configured_names = {field_settings.get("name", key) for key, field_settings in settings.items()}
    expected_dtype = self.field_config.get("dtype", None)
    if expected_dtype is not None:
        actual_names = records_arr.dtype.names
        for field in expected_dtype.names:
            if field in actual_names or field in configured_names:
                continue
            raise TypeError(f"Field '{field}' from {expected_dtype} cannot be found in records or config")

    if col_mapper is None:
        col_field_name = self.get_field_name("col")
        col_mapper = ColumnMapper(wrapper, records_arr[col_field_name])

    Analyzable.__init__(self, wrapper, records_arr=records_arr, col_mapper=col_mapper, **kwargs)

    self._records_arr = records_arr
    self._col_mapper = col_mapper

    # Only slices of rows can be selected
    self._range_only_select = True

    # Copy writeable attrs
    self._field_config = type(self)._field_config.copy()
def select_cols(
    self,
    col_idxs: tp.MaybeIndexArray,
    jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
    """Select columns.

    Returns indices and new record array. Automatically decides whether to use column lengths
    or column map.

    `col_idxs` may be an array of column positions or a slice. A slice that covers everything
    returns the records unchanged (not copied). Any other slice is resolved against the number
    of columns of the column mapper's wrapper, such that open-ended bounds and steps
    are honored."""
    if len(self.values) == 0:
        return np.arange(len(self.values)), self.values
    if isinstance(col_idxs, slice):
        if col_idxs.start is None and col_idxs.stop is None and col_idxs.step in (None, 1):
            return np.arange(len(self.values)), self.values
        # slice.indices() fills in missing bounds and the step; passing None bounds
        # to np.arange directly would fail or select the wrong columns
        n_cols = len(self.col_mapper.wrapper.columns)
        col_idxs = np.arange(*col_idxs.indices(n_cols))
    if self.col_mapper.is_sorted():
        func = jit_reg.resolve_option(nb.record_col_lens_select_nb, jitted)
        new_indices, new_records_arr = func(self.values, self.col_mapper.col_lens, to_1d_array(col_idxs))  # faster
    else:
        func = jit_reg.resolve_option(nb.record_col_map_select_nb, jitted)
        new_indices, new_records_arr = func(
            self.values, self.col_mapper.col_map, to_1d_array(col_idxs)
        )  # more flexible
    return new_indices, new_records_arr
def resample_records_arr(self, resampler: tp.Union[Resampler, tp.PandasResampler]) -> tp.RecordArray:
    """Perform resampling on the record array.

    Returns a copy of `Records.records_arr` in which every field mapped to "index" is
    translated through the array returned by `map_to_target_index(return_index=False)`
    of the resampler. A `resampler` that is not a `Resampler` instance is first converted
    with `Resampler.from_pd_resampler`."""
    if isinstance(resampler, Resampler):
        _resampler = resampler
    else:
        _resampler = Resampler.from_pd_resampler(resampler)
    new_records_arr = self.records_arr.copy()
    index_map = None
    for field_name in self.values.dtype.names:
        field_mapping = self.get_field_mapping(field_name)
        if isinstance(field_mapping, str) and field_mapping == "index":
            if index_map is None:
                # Built on first use and shared by all index fields.
                # NOTE(review): assumes the resampler returns the same map on every call - confirm
                index_map = _resampler.map_to_target_index(return_index=False)
            new_records_arr[field_name] = index_map[new_records_arr[field_name]]
    return new_records_arr
See `vectorbtpro.records.col_mapper.ColumnMapper`.""" return self._col_mapper @property def field_names(self) -> tp.List[str]: """Field names.""" return list(self.values.dtype.fields.keys()) def to_readable(self, expand_columns: bool = False) -> tp.Frame: """Get records in a human-readable format.""" new_columns = list() field_settings = self.field_config.get("settings", {}) for field_name in self.field_names: if field_name in field_settings: dct = field_settings[field_name] if dct.get("ignore", False): continue field_name = dct.get("name", field_name) if "title" in dct: title = dct["title"] else: title = field_name if "mapping" in dct: if isinstance(dct["mapping"], str) and dct["mapping"] == "index": new_columns.append(pd.Series(self.get_map_field_to_index(field_name), name=title)) elif isinstance(dct["mapping"], str) and dct["mapping"] == "columns": column_index = self.get_map_field_to_columns(field_name) if expand_columns and isinstance(column_index, pd.MultiIndex): column_frame = index_to_frame(column_index, reset_index=True) new_columns.append(column_frame.add_prefix(f"{title}: ")) else: column_sr = index_to_series(column_index, reset_index=True) if expand_columns and self.wrapper.ndim == 2 and column_sr.name is not None: new_columns.append(column_sr.rename(f"{title}: {column_sr.name}")) else: new_columns.append(column_sr.rename(title)) else: new_columns.append(pd.Series(self.get_apply_mapping_arr(field_name), name=title)) else: new_columns.append(pd.Series(self.values[field_name], name=title)) else: new_columns.append(pd.Series(self.values[field_name], name=field_name)) records_readable = pd.concat(new_columns, axis=1) if all([isinstance(col, tuple) for col in records_readable.columns]): records_readable.columns = pd.MultiIndex.from_tuples(records_readable.columns) return records_readable @property def records_readable(self) -> tp.Frame: """`Records.to_readable` with default arguments.""" return self.to_readable() readable = records_readable def 
get_field_setting(self, field: str, setting: str, default: tp.Any = None) -> tp.Any: """Get any setting of the field. Uses `Records.field_config`.""" return self.field_config.get("settings", {}).get(field, {}).get(setting, default) def get_field_name(self, field: str) -> str: """Get the name of the field. Uses `Records.field_config`..""" return self.get_field_setting(field, "name", field) def get_field_title(self, field: str) -> str: """Get the title of the field. Uses `Records.field_config`.""" return self.get_field_setting(field, "title", field) def get_field_mapping(self, field: str) -> tp.Optional[tp.MappingLike]: """Get the mapping of the field. Uses `Records.field_config`.""" return self.get_field_setting(field, "mapping", None) def get_field_arr(self, field: str, copy: bool = False) -> tp.Array1d: """Get the array of the field. Uses `Records.field_config`.""" out = self.values[self.get_field_name(field)] if copy: out = out.copy() return out def get_map_field(self, field: str, **kwargs) -> MappedArray: """Get the mapped array of the field. Uses `Records.field_config`.""" mapping = self.get_field_mapping(field) if isinstance(mapping, str) and mapping == "ids": mapping = None return self.map_field(self.get_field_name(field), mapping=mapping, **kwargs) def get_map_field_to_index(self, field: str, minus_one_to_zero: bool = False, **kwargs) -> tp.Index: """Get the mapped array on the field, with index applied. Uses `Records.field_config`.""" return self.get_map_field(field, **kwargs).to_index(minus_one_to_zero=minus_one_to_zero) def get_map_field_to_columns(self, field: str, **kwargs) -> tp.Index: """Get the mapped array on the field, with columns applied. Uses `Records.field_config`.""" return self.get_map_field(field, **kwargs).to_columns() def get_apply_mapping_arr(self, field: str, mapping_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Array1d: """Get the mapped array on the field, with mapping applied. 
def sort(self: RecordsT, incl_id: bool = False, group_by: tp.GroupByLike = None, **kwargs) -> RecordsT:
    """Sort records by columns (primary) and ids (secondary, optional).

    If `incl_id` is False, a stable sort is used, such that records of the same column
    keep their relative order. If the records are already sorted, no reordering takes place.

    `**kwargs` are passed to `Records.replace`, and the result is regrouped using `group_by`.

    !!! note
        Sorting is expensive. A better approach is to append records already in the correct order."""
    if self.is_sorted(incl_id=incl_id):
        return self.replace(**kwargs).regroup(group_by)
    if incl_id:
        ind = np.lexsort((self.id_arr, self.col_arr))  # expensive!
    else:
        # The default kind of np.argsort is not stable and could shuffle
        # records that belong to the same column
        ind = np.argsort(self.col_arr, kind="stable")
    return self.replace(records_arr=self.values[ind], **kwargs).regroup(group_by)
tp.Union[None, str, tp.Array1d] = None, mapping: tp.Optional[tp.MappingLike] = None, group_by: tp.GroupByLike = None, **kwargs, ) -> MappedArray: """Convert array to mapped array. The length of the array must match that of the records.""" if not isinstance(a, np.ndarray): a = np.asarray(a) checks.assert_shape_equal(a, self.values) if idx_arr is None: idx_arr = self.idx_arr elif isinstance(idx_arr, str): idx_arr = self.get_field_arr(idx_arr) return MappedArray( self.wrapper, a, self.col_arr, id_arr=self.id_arr, idx_arr=idx_arr, mapping=mapping, col_mapper=self.col_mapper, **kwargs, ).regroup(group_by) def map_field(self, field: str, **kwargs) -> MappedArray: """Convert field to mapped array. `**kwargs` are passed to `Records.map_array`.""" mapped_arr = self.values[field] return self.map_array(mapped_arr, **kwargs) @hybrid_method def map( cls_or_self, map_func_nb: tp.Union[tp.RecordsMapFunc, tp.RecordsMapMetaFunc], *args, dtype: tp.Optional[tp.DTypeLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, col_mapper: tp.Optional[ColumnMapper] = None, **kwargs, ) -> MappedArray: """Map each record to a scalar value. Returns mapped array. See `vectorbtpro.records.nb.map_records_nb`. For details on the meta version, see `vectorbtpro.records.nb.map_records_meta_nb`. 
`**kwargs` are passed to `Records.map_array`.""" if isinstance(cls_or_self, type): checks.assert_not_none(col_mapper, arg_name="col_mapper") func = jit_reg.resolve_option(nb.map_records_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) mapped_arr = func(len(col_mapper.col_arr), map_func_nb, *args) mapped_arr = np.asarray(mapped_arr, dtype=dtype) return MappedArray(col_mapper.wrapper, mapped_arr, col_mapper.col_arr, col_mapper=col_mapper, **kwargs) else: func = jit_reg.resolve_option(nb.map_records_nb, jitted) func = ch_reg.resolve_option(func, chunked) mapped_arr = func(cls_or_self.values, map_func_nb, *args) mapped_arr = np.asarray(mapped_arr, dtype=dtype) return cls_or_self.map_array(mapped_arr, **kwargs) @hybrid_method def apply( cls_or_self, apply_func_nb: tp.Union[tp.ApplyFunc, tp.ApplyMetaFunc], *args, group_by: tp.GroupByLike = None, apply_per_group: bool = False, dtype: tp.Optional[tp.DTypeLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, col_mapper: tp.Optional[ColumnMapper] = None, **kwargs, ) -> MappedArray: """Apply function on records per column/group. Returns mapped array. Applies per group if `apply_per_group` is True. See `vectorbtpro.records.nb.apply_nb`. For details on the meta version, see `vectorbtpro.records.nb.apply_meta_nb`. 
def get_pd_mask(
    self,
    idx_arr: tp.Union[None, str, tp.Array1d] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Get mask in form of a Series/DataFrame from row and column indices.

    `idx_arr` can be an array of row indices, the name of a field to take them from,
    or None to use `Records.idx_arr`. Raises `ValueError` if neither is available."""
    if isinstance(idx_arr, str):
        row_idxs = self.get_field_arr(idx_arr)
    elif idx_arr is not None:
        row_idxs = idx_arr
    else:
        row_idxs = self.idx_arr
        if row_idxs is None:
            raise ValueError("Must pass idx_arr")
    col_idxs = self.col_mapper.get_col_arr(group_by=group_by)
    mask = np.full(self.wrapper.get_shape_2d(group_by=group_by), False)
    # Mark every (row, column) pair that holds a record
    mask[row_idxs, col_idxs] = True
    wrap_kwargs = resolve_dict(wrap_kwargs)
    return self.wrapper.wrap(mask, group_by=group_by, **wrap_kwargs)


@property
def pd_mask(self) -> tp.SeriesFrame:
    """`Records.get_pd_mask` with default arguments."""
    return self.get_pd_mask()
def prepare_customdata(
    self,
    incl_fields: tp.Optional[tp.Sequence[str]] = None,
    excl_fields: tp.Optional[tp.Sequence[str]] = None,
    append_info: tp.Optional[tp.Sequence[tp.Tuple]] = None,
    mask: tp.Optional[tp.Array1d] = None,
) -> tp.Tuple[tp.Array2d, str]:
    """Prepare customdata and hoverinfo for Plotly.

    Will display all fields in the data type or only those in `incl_fields`, unless any of them
    has the field config setting `as_customdata` disabled, or it's listed in `excl_fields`.

    Additionally, you can define `hovertemplate` in the field config such as by using
    `vectorbtpro.utils.template.Sub` where `title` is substituted by the title and `index`
    is substituted by (final) index in the customdata. If provided as a string, will be wrapped
    with `vectorbtpro.utils.template.Sub`. Defaults to "$title: %{customdata[$index]}" for fields
    with a mapping (which are stringified automatically) and to "$title: %{customdata[$index]:,}"
    for all other fields.

    To append one or more custom arrays, provide `append_info` as a list of tuples, each consisting
    of a 1-dim NumPy array, title, and optionally hoverinfo. If the array's data type is `object`,
    will treat it as strings, otherwise as numbers.

    If `mask` is provided, it is applied to each array.

    Returns the arrays stacked along the second axis and the hovertemplate, in which each
    entry is prefixed with "<br>" unless it already starts with it."""
    customdata_info = []
    if incl_fields is not None:
        iterate_over_names = incl_fields
    else:
        iterate_over_names = self.field_config.get("dtype").names
    for field in iterate_over_names:
        if excl_fields is not None and field in excl_fields:
            continue
        field_as_customdata = self.get_field_setting(field, "as_customdata", True)
        if field_as_customdata:
            field_mapping = self.get_field_setting(field, "mapping", None)
            if field_mapping is not None:
                # Mapped fields are displayed as strings
                field_arr = self.get_apply_mapping_str_arr(field)
                field_hovertemplate = self.get_field_setting(
                    field,
                    "hovertemplate",
                    "$title: %{customdata[$index]}",
                )
            else:
                # Unmapped fields get the thousands separator
                field_arr = self.get_apply_mapping_arr(field)
                field_hovertemplate = self.get_field_setting(
                    field,
                    "hovertemplate",
                    "$title: %{customdata[$index]:,}",
                )
            if isinstance(field_hovertemplate, str):
                field_hovertemplate = Sub(field_hovertemplate)
            field_title = self.get_field_title(field)
            customdata_info.append((field_arr, field_title, field_hovertemplate))
    if append_info is not None:
        for info in append_info:
            checks.assert_instance_of(info, tuple)
            if len(info) == 2:
                if info[0].dtype == object:
                    info += ("$title: %{customdata[$index]}",)
                else:
                    info += ("$title: %{customdata[$index]:,}",)
            if isinstance(info[2], str):
                info = (info[0], info[1], Sub(info[2]))
            customdata_info.append(info)
    customdata = []
    hovertemplate = []
    for i, (info_arr, info_title, info_template) in enumerate(customdata_info):
        if mask is not None:
            customdata.append(info_arr[mask])
        else:
            customdata.append(info_arr)
        _hovertemplate = info_template.substitute(dict(title=info_title, index=i))
        # Plotly renders "<br>" as a line break in hover labels
        if not _hovertemplate.startswith("<br>"):
            _hovertemplate = "<br>" + _hovertemplate
        hovertemplate.append(_hovertemplate)
    return np.stack(customdata, axis=1), "\n".join(hovertemplate)
def fix_field_in_records(
    record_arrays: tp.List[tp.RecordArray],
    chunk_meta: tp.Iterable[ChunkMeta],
    ann_args: tp.Optional[tp.AnnArgs] = None,
    mapper: tp.Optional[ChunkMapper] = None,
    field: str = "col",
) -> None:
    """Fix a field of the record array in each chunk.

    For each chunk, the start of the chunk is added in place to `field` of the record array
    stored under the chunk's index. If `mapper` is provided, the start of the mapped chunk
    (`mapper.map(chunk_meta, ann_args=ann_args)`) is used instead."""
    for _chunk_meta in chunk_meta:
        if mapper is None:
            offset = _chunk_meta.start
        else:
            offset = mapper.map(_chunk_meta, ann_args=ann_args).start
        record_arrays[_chunk_meta.idx][field] += offset


def merge_records(
    results: tp.List[tp.RecordArray],
    chunk_meta: tp.Iterable[ChunkMeta],
    ann_args: tp.Optional[tp.AnnArgs] = None,
    mapper: tp.Optional[ChunkMapper] = None,
) -> tp.RecordArray:
    """Merge chunks of record arrays.

    Fixes the fields "col" and "group" (if present in the first array) in place using
    `fix_field_in_records` and concatenates the arrays.

    Mapper is only applied on the column field."""
    # The metadata is walked once per fixed field, so a one-shot iterator
    # would be exhausted after the first field and leave the second one unfixed
    chunk_meta = list(chunk_meta)
    if "col" in results[0].dtype.fields:
        fix_field_in_records(results, chunk_meta, ann_args=ann_args, mapper=mapper, field="col")
    if "group" in results[0].dtype.fields:
        fix_field_in_records(results, chunk_meta, field="group")
    return np.concatenate(results)
@hybrid_method
def column_stack(
    cls_or_self: tp.MaybeType[ColumnMapperT],
    *objs: tp.MaybeTuple[ColumnMapperT],
    wrapper_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> ColumnMapperT:
    """Stack multiple `ColumnMapper` instances along columns.

    Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.

    If `col_arr` is not provided, the column arrays are taken in column-map order, shifted by
    the number of columns of all preceding objects, and concatenated. If none of the objects
    holds any element, the new column array is empty and has the data type of the first one.

    Raises `TypeError` if any object is not an instance of `ColumnMapper`.

    !!! note
        Will produce a column-sorted array."""
    if not isinstance(cls_or_self, type):
        objs = (cls_or_self, *objs)
        cls = type(cls_or_self)
    else:
        cls = cls_or_self
    if len(objs) == 1:
        objs = objs[0]
    objs = list(objs)
    for obj in objs:
        if not checks.is_instance_of(obj, ColumnMapper):
            raise TypeError("Each object to be merged must be an instance of ColumnMapper")
    if "wrapper" not in kwargs:
        if wrapper_kwargs is None:
            wrapper_kwargs = {}
        kwargs["wrapper"] = ArrayWrapper.column_stack(
            *[obj.wrapper for obj in objs],
            **wrapper_kwargs,
        )

    if "col_arr" not in kwargs:
        col_arrs = []
        col_sum = 0
        for obj in objs:
            col_idxs, col_lens = obj.col_map
            if len(col_idxs) > 0:
                col_arrs.append(obj.col_arr[col_idxs] + col_sum)
            col_sum += obj.wrapper.shape_2d[1]
        if len(col_arrs) == 0:
            # np.concatenate refuses an empty list; stacking empty mappers must stay empty
            kwargs["col_arr"] = np.array([], dtype=np.asarray(objs[0].col_arr).dtype)
        else:
            kwargs["col_arr"] = np.concatenate(col_arrs)
    kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    return cls(**kwargs)
@cached_method(whitelist=True)
def get_col_arr(self, group_by: tp.GroupByLike = None) -> tp.Array1d:
    """Get group-aware column array.

    Returns `ColumnMapper.col_arr` as-is if the grouper yields no groups for `group_by`,
    otherwise the group of each element's column."""
    groups = self.wrapper.grouper.get_groups(group_by=group_by)
    if groups is None:
        return self.col_arr
    # Look up the group of each element's column
    return groups[self.col_arr]
"""Column lengths. Faster than `ColumnMapper.col_map` but only compatible with sorted columns.""" func = jit_reg.resolve_option(nb.col_lens_nb, None) return func(self.col_arr, len(self.wrapper.columns)) @cached_method(whitelist=True) def get_col_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None) -> tp.GroupLens: """Get group-aware column lengths.""" if not self.wrapper.grouper.is_grouped(group_by=group_by): return self.col_lens col_arr = self.get_col_arr(group_by=group_by) columns = self.wrapper.get_columns(group_by=group_by) func = jit_reg.resolve_option(nb.col_lens_nb, jitted) return func(col_arr, len(columns)) @cached_property(whitelist=True) def col_map(self) -> tp.GroupMap: """Column map. More flexible than `ColumnMapper.col_lens`. More suited for mapped arrays.""" func = jit_reg.resolve_option(nb.col_map_nb, None) return func(self.col_arr, len(self.wrapper.columns)) @cached_method(whitelist=True) def get_col_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None) -> tp.GroupMap: """Get group-aware column map.""" if not self.wrapper.grouper.is_grouped(group_by=group_by): return self.col_map col_arr = self.get_col_arr(group_by=group_by) columns = self.wrapper.get_columns(group_by=group_by) func = jit_reg.resolve_option(nb.col_map_nb, jitted) return func(col_arr, len(columns)) @cached_method(whitelist=True) def is_sorted(self, jitted: tp.JittedOption = None) -> bool: """Check whether column array is sorted.""" func = jit_reg.resolve_option(nb.is_col_sorted_nb, jitted) return func(self.col_arr) @cached_property(whitelist=True) def new_id_arr(self) -> tp.Array1d: """Generate a new id array.""" func = jit_reg.resolve_option(nb.generate_ids_nb, None) return func(self.col_arr, self.wrapper.shape_2d[1]) @cached_method(whitelist=True) def get_new_id_arr(self, group_by: tp.GroupByLike = None) -> tp.Array1d: """Generate a new group-aware id array.""" group_arr = self.wrapper.grouper.get_groups(group_by=group_by) if group_arr is 
not None: col_arr = group_arr[self.col_arr] else: col_arr = self.col_arr columns = self.wrapper.get_columns(group_by=group_by) func = jit_reg.resolve_option(nb.generate_ids_nb, None) return func(col_arr, len(columns)) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Class decorators for records.""" import keyword from functools import partial from vectorbtpro import _typing as tp from vectorbtpro.records.mapped_array import MappedArray from vectorbtpro.utils import checks from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig from vectorbtpro.utils.decorators import cacheable_property, cached_property from vectorbtpro.utils.mapping import to_value_mapping __all__ = [] def override_field_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper: """Class decorator to override the field config of a class subclassing `vectorbtpro.records.base.Records`. Instead of overriding `_field_config` class attribute, you can pass `config` directly to this decorator. 
def attach_fields(*args, on_conflict: str = "raise") -> tp.FlexClassWrapper:
    """Class decorator to attach field properties in a `vectorbtpro.records.base.Records` class.

    Will extract `dtype` and other relevant information from `vectorbtpro.records.base.Records.field_config`
    and map its fields as properties. This behavior can be changed by using `config`.
    If the field config has no `dtype`, the class is returned unchanged.

    !!! note
        Make sure to run `attach_fields` after `override_field_config`.

    `config` must contain fields (keys) and dictionaries (values) with the following keys:

    * `attach`: Whether to attach the field property. Can be provided as a string to be used
        as a target attribute name. Defaults to True.
    * `defaults`: Dictionary with default keyword arguments for
        `vectorbtpro.records.base.Records.map_field`. Defaults to an empty dict.
    * `attach_filters`: Whether to attach filters based on the field's values. Can be provided as a dict
        to be used instead of the mapping (filter value -> target filter name). Defaults to False.
        If True, defaults to `mapping` in `vectorbtpro.records.base.Records.field_config`.
    * `filter_defaults`: Dictionary with default keyword arguments for
        `vectorbtpro.records.base.Records.apply_mask`. Can be provided by target filter name.
        Defaults to an empty dict.
    * `on_conflict`: Overrides global `on_conflict` for both field and filter properties.

    Any potential attribute name is prepared by placing an underscore before each capital letter
    that follows a lowercase letter and converting to the lower case. A name that is
    a Python keyword gets an underscore appended.

    If an attribute with the same name already exists in the class but the name is not listed
    in the field config:

    * it will be overridden if `on_conflict` is 'override'
    * it will be ignored if `on_conflict` is 'ignore'
    * an error will be raised if `on_conflict` is 'raise'

    Can be applied as `@attach_fields`, `@attach_fields(config)`, `@attach_fields(on_conflict=...)`,
    or called directly as `attach_fields(cls)` and `attach_fields(cls, config)`.
    """

    def wrapper(cls: tp.Type[tp.T], config: tp.DictLike = None) -> tp.Type[tp.T]:
        checks.assert_subclass_of(cls, "Records")
        dtype = cls.field_config.get("dtype", None)
        if dtype is None:
            # Without a dtype there are no fields to attach; dereferencing
            # `dtype.fields` here would fail with an AttributeError
            return cls
        checks.assert_not_none(dtype.fields)
        if config is None:
            config = {}
        field_settings = cls.field_config.get("settings", {})

        def _prepare_attr_name(attr_name: str) -> str:
            checks.assert_instance_of(attr_name, str)
            # Keep "NaN" from being split into "na_n"
            attr_name = attr_name.replace("NaN", "Nan")
            had_leading_underscore = attr_name.startswith("_")
            chars = []
            prev_char = ""
            for char in attr_name:
                if char.isupper() and prev_char.islower():
                    chars.append("_")
                chars.append(char)
                prev_char = char
            attr_name = "".join(chars)
            if not had_leading_underscore and attr_name.startswith("_"):
                attr_name = attr_name[1:]
            attr_name = attr_name.lower()
            if keyword.iskeyword(attr_name):
                attr_name += "_"
            return attr_name.replace("__", "_")

        def _check_attr_name(attr_name: str, _on_conflict: str = on_conflict) -> None:
            # Consider only attributes that are not listed in the field config
            if attr_name not in field_settings and hasattr(cls, attr_name):
                mode = _on_conflict.lower()
                if mode == "raise":
                    raise ValueError(f"An attribute with the name '{attr_name}' already exists in {cls}")
                if mode in ("ignore", "override"):
                    # NOTE(review): both modes let the caller attach the new property, so 'ignore'
                    # currently replaces the existing attribute just like 'override' does.
                    # Confirm whether 'ignore' is meant to keep the existing attribute instead.
                    return
                raise ValueError(f"Value '{_on_conflict}' is invalid for on_conflict")
            if keyword.iskeyword(attr_name):
                raise ValueError(f"Name '{attr_name}' is a keyword and cannot be used as an attribute name")

        for field_name in dtype.names:
            settings = config.get(field_name, {})
            attach = settings.get("attach", True)
            if isinstance(attach, bool):
                target_name = field_name
            else:
                # A non-boolean value is the name of the target attribute
                target_name = attach
                attach = True
            defaults = settings.get("defaults", None)
            if defaults is None:
                defaults = {}
            attach_filters = settings.get("attach_filters", False)
            filter_defaults = settings.get("filter_defaults", None)
            if filter_defaults is None:
                filter_defaults = {}
            _on_conflict = settings.get("on_conflict", on_conflict)

            if attach:
                target_name = _prepare_attr_name(target_name)
                _check_attr_name(target_name, _on_conflict)

                # Loop variables are bound as defaults to avoid late binding in the closure
                def new_prop(
                    self,
                    _field_name: str = field_name,
                    _defaults: tp.KwargsLike = defaults,
                ) -> MappedArray:
                    return self.get_map_field(_field_name, **_defaults)

                new_prop.__name__ = target_name
                new_prop.__module__ = cls.__module__
                new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
                new_prop.__doc__ = f"Mapped array of the field `{field_name}`."
                setattr(cls, target_name, cached_property(new_prop))

            if attach_filters:
                use_field_mapping = isinstance(attach_filters, bool)
                if use_field_mapping:
                    mapping = field_settings.get(field_name, {}).get("mapping", None)
                else:
                    mapping = attach_filters
                if mapping is None:
                    raise ValueError(f"Field '{field_name}': Mapping is required to attach filters")
                mapping = to_value_mapping(mapping)

                for filter_value, target_filter_name in mapping.items():
                    if target_filter_name is None:
                        continue
                    if use_field_mapping:
                        # Names taken from the field's own mapping are prefixed with the field name
                        target_filter_name = field_name + "_" + target_filter_name
                    target_filter_name = _prepare_attr_name(target_filter_name)
                    _check_attr_name(target_filter_name, _on_conflict)
                    if target_filter_name in filter_defaults:
                        resolved_filter_defaults = filter_defaults[target_filter_name]
                    else:
                        resolved_filter_defaults = filter_defaults

                    def new_filter_prop(
                        self,
                        _field_name: str = field_name,
                        _filter_value: tp.Any = filter_value,
                        _filter_defaults: tp.KwargsLike = resolved_filter_defaults,
                    ) -> MappedArray:
                        filter_mask = self.get_field_arr(_field_name) == _filter_value
                        return self.apply_mask(filter_mask, **_filter_defaults)

                    new_filter_prop.__name__ = target_filter_name
                    new_filter_prop.__module__ = cls.__module__
                    new_filter_prop.__qualname__ = f"{cls.__name__}.{new_filter_prop.__name__}"
                    new_filter_prop.__doc__ = f"Records filtered by `{field_name} == {filter_value}`."
                    setattr(cls, target_filter_name, cached_property(new_filter_prop))

        return cls

    if len(args) == 0:
        return wrapper
    if len(args) == 1:
        if isinstance(args[0], type):
            return wrapper(args[0])
        return partial(wrapper, config=args[0])
    if len(args) == 2:
        return wrapper(args[0], config=args[1])
    raise ValueError("Either class, config, class and config, or keyword arguments must be passed")


def attach_shortcut_properties(config: Config) -> tp.ClassWrapper:
    """Class decorator to attach shortcut properties.

    `config` must contain target property names (keys) and settings (values) with the following keys:

    * `method_name`: Name of the source method. Defaults to the target name prepended with the prefix `get_`.
    * `obj_type`: Type of the returned object. Can be 'array' for 2-dim arrays, 'red_array' for 1-dim arrays,
        'records' for record arrays, and 'mapped_array' for mapped arrays. Defaults to 'records'.
    * `group_by_aware`: Whether the returned object is aligned based on the current grouping.
        Defaults to True.
    * `method_kwargs`: Keyword arguments passed to the source method. Defaults to None.
    * `decorator`: Defaults to `vectorbtpro.utils.decorators.cached_property` for object types
        'records' and 'red_array'. Otherwise, to `vectorbtpro.utils.decorators.cacheable_property`.
    * `decorator_kwargs`: Keyword arguments passed to the decorator. By default,
        includes options `obj_type` and `group_by_aware`.
    * `docstring`: Method docstring.

    Target property names cannot start with the prefix `get_`.

    The class must be a subclass of `vectorbtpro.records.base.Records`."""

    def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
        checks.assert_subclass_of(cls, "Records")

        for target_name, settings in config.items():
            if target_name.startswith("get_"):
                raise ValueError(f"Property names cannot have prefix 'get_' ('{target_name}')")
            method_name = settings.get("method_name", "get_" + target_name)
            obj_type = settings.get("obj_type", "records")
            group_by_aware = settings.get("group_by_aware", True)
            method_kwargs = settings.get("method_kwargs", None)
            method_kwargs = resolve_dict(method_kwargs)
            decorator = settings.get("decorator", None)
            if decorator is None:
                if obj_type in ("red_array", "records"):
                    decorator = cached_property
                else:
                    decorator = cacheable_property
            decorator_kwargs = merge_dicts(
                dict(obj_type=obj_type, group_by_aware=group_by_aware),
                settings.get("decorator_kwargs", None),
            )
            docstring = settings.get("docstring", None)
            if docstring is None:
                if len(method_kwargs) == 0:
                    docstring = f"`{cls.__name__}.{method_name}` with default arguments."
                else:
                    docstring = f"`{cls.__name__}.{method_name}` with arguments `{method_kwargs}`."

            # Loop variables are bound as defaults to avoid late binding in the closure
            def new_prop(self, _method_name: str = method_name, _method_kwargs: tp.Kwargs = method_kwargs) -> tp.Any:
                return getattr(self, _method_name)(**_method_kwargs)

            new_prop.__name__ = target_name
            new_prop.__module__ = cls.__module__
            new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
            new_prop.__doc__ = docstring
            setattr(cls, new_prop.__name__, decorator(new_prop, **decorator_kwargs))
        return cls

    return wrapper
# =================================================================================== """Base class for working with mapped arrays. This class takes the mapped array and the corresponding column and (optionally) index arrays, and offers features to directly process the mapped array without converting it to pandas; for example, to compute various statistics by column, such as standard deviation. Consider the following example: ```pycon >>> from vectorbtpro import * >>> a = np.array([10., 11., 12., 13., 14., 15., 16., 17., 18.]) >>> col_arr = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) >>> idx_arr = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) >>> wrapper = vbt.ArrayWrapper(index=['x', 'y', 'z'], ... columns=['a', 'b', 'c'], ndim=2, freq='1 day') >>> ma = vbt.MappedArray(wrapper, a, col_arr, idx_arr=idx_arr) ``` ## Reducing Using `MappedArray`, we can then reduce by column as follows: * Use already provided reducers such as `MappedArray.mean`: ```pycon >>> ma.mean() a 11.0 b 14.0 c 17.0 dtype: float64 ``` * Use `MappedArray.to_pd` to map to pandas and then reduce manually (expensive): ```pycon >>> ma.to_pd().mean() a 11.0 b 14.0 c 17.0 dtype: float64 ``` * Use `MappedArray.reduce` to reduce using a custom function: ```pycon >>> # Reduce to a scalar >>> @njit ... def pow_mean_reduce_nb(a, pow): ... return np.mean(a ** pow) >>> ma.reduce(pow_mean_reduce_nb, 2) a 121.666667 b 196.666667 c 289.666667 dtype: float64 >>> # Reduce to an array >>> @njit ... def min_max_reduce_nb(a): ... return np.array([np.min(a), np.max(a)]) >>> ma.reduce(min_max_reduce_nb, returns_array=True, ... wrap_kwargs=dict(name_or_index=['min', 'max'])) a b c min 10.0 13.0 16.0 max 12.0 15.0 18.0 >>> # Reduce to an array of indices >>> @njit ... def idxmin_idxmax_reduce_nb(a): ... return np.array([np.argmin(a), np.argmax(a)]) >>> ma.reduce(idxmin_idxmax_reduce_nb, returns_array=True, ... 
returns_idx=True, wrap_kwargs=dict(name_or_index=['idxmin', 'idxmax'])) a b c idxmin x x x idxmax z z z >>> # Reduce using a meta function to combine multiple mapped arrays >>> @njit ... def mean_ratio_reduce_meta_nb(idxs, col, a, b): ... return np.mean(a[idxs]) / np.mean(b[idxs]) >>> vbt.MappedArray.reduce(mean_ratio_reduce_meta_nb, ... ma.values - 1, ma.values + 1, col_mapper=ma.col_mapper) a 0.833333 b 0.866667 c 0.888889 Name: reduce, dtype: float64 ``` ## Mapping Use `MappedArray.apply` to apply a function on each column/group: ```pycon >>> @njit ... def cumsum_apply_nb(a): ... return np.cumsum(a) >>> ma.apply(cumsum_apply_nb) >>> ma.apply(cumsum_apply_nb).values array([10., 21., 33., 13., 27., 42., 16., 33., 51.]) >>> group_by = np.array(['first', 'first', 'second']) >>> ma.apply(cumsum_apply_nb, group_by=group_by, apply_per_group=True).values array([10., 21., 33., 46., 60., 75., 16., 33., 51.]) >>> # Apply using a meta function >>> @njit ... def cumsum_apply_meta_nb(ridxs, col, a): ... return np.cumsum(a[ridxs]) >>> vbt.MappedArray.apply(cumsum_apply_meta_nb, ma.values, col_mapper=ma.col_mapper).values array([10., 21., 33., 13., 27., 42., 16., 33., 51.]) ``` Notice how cumsum resets at each column in the first example and at each group in the second example. ## Conversion We can unstack any `MappedArray` instance to pandas: * Given `idx_arr` was provided: ```pycon >>> ma.to_pd() a b c x 10.0 13.0 16.0 y 11.0 14.0 17.0 z 12.0 15.0 18.0 ``` !!! note Will throw a warning if there are multiple values pointing to the same position. * In case `group_by` was provided, index can be ignored, or there are position conflicts: ```pycon >>> ma.to_pd(group_by=np.array(['first', 'first', 'second']), ignore_index=True) first second 0 10.0 16.0 1 11.0 17.0 2 12.0 18.0 3 13.0 NaN 4 14.0 NaN 5 15.0 NaN ``` ## Resolving conflicts Sometimes, we may encounter multiple values for each index and column combination. 
In such case, we can use `MappedArray.reduce_segments` to aggregate "duplicate" elements. For example, let's sum up duplicate values per each index and column combination: ```pycon >>> ma_conf = ma.replace(idx_arr=np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])) >>> ma_conf.to_pd() UserWarning: Multiple values are pointing to the same position. Only the latest value is used. a b c x 12.0 NaN NaN y NaN 15.0 NaN z NaN NaN 18.0 >>> @njit ... def sum_reduce_nb(a): ... return np.sum(a) >>> ma_no_conf = ma_conf.reduce_segments( ... (ma_conf.idx_arr, ma_conf.col_arr), ... sum_reduce_nb ... ) >>> ma_no_conf.to_pd() a b c x 33.0 NaN NaN y NaN 42.0 NaN z NaN NaN 51.0 ``` ## Filtering Use `MappedArray.apply_mask` to filter elements per column/group: ```pycon >>> mask = [True, False, True, False, True, False, True, False, True] >>> filtered_ma = ma.apply_mask(mask) >>> filtered_ma.count() a 2 b 1 c 2 dtype: int64 >>> filtered_ma.id_arr array([0, 2, 4, 6, 8]) ``` ## Grouping One of the key features of `MappedArray` is that we can perform reducing operations on a group of columns as if they were a single column. Groups can be specified by `group_by`, which can be anything from positions or names of column levels, to a NumPy array with actual groups. 
There are multiple ways to define grouping: * When creating `MappedArray`, pass `group_by` to `vectorbtpro.base.wrapping.ArrayWrapper`: ```pycon >>> group_by = np.array(['first', 'first', 'second']) >>> grouped_wrapper = wrapper.replace(group_by=group_by) >>> grouped_ma = vbt.MappedArray(grouped_wrapper, a, col_arr, idx_arr=idx_arr) >>> grouped_ma.mean() first 12.5 second 17.0 dtype: float64 ``` * Regroup an existing `MappedArray`: ```pycon >>> ma.regroup(group_by).mean() first 12.5 second 17.0 dtype: float64 ``` * Pass `group_by` directly to the reducing method: ```pycon >>> ma.mean(group_by=group_by) first 12.5 second 17.0 dtype: float64 ``` In the same way, we can disable or modify any existing grouping: ```pycon >>> grouped_ma.mean(group_by=False) a 11.0 b 14.0 c 17.0 dtype: float64 ``` !!! note Grouping applies only to reducing operations; there is no change to the arrays. ## Operators `MappedArray` implements arithmetic, comparison, and logical operators. We can perform basic operations (such as addition) on mapped arrays as if they were NumPy arrays. ```pycon >>> ma ** 2 >>> ma * np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> ma + ma ``` !!! note Ensure that your `MappedArray` operand is on the left if the other operand is an array. If two `MappedArray` operands have different metadata, the metadata is copied from the first one, but at least their `id_arr` and `col_arr` must match. ## Indexing Like any other class subclassing `vectorbtpro.base.wrapping.Wrapping`, we can do pandas indexing on a `MappedArray` instance, which forwards the indexing operation to each object with columns: ```pycon >>> ma['a'].values array([10., 11., 12.]) >>> grouped_ma['first'].values array([10., 11., 12., 13., 14., 15.]) ``` !!! note Changing index (time axis) is not supported. The object should be treated as a Series rather than a DataFrame; for example, use `some_field.iloc[0]` instead of `some_field.iloc[:, 0]` to get the first column.
Indexing behavior depends solely upon `vectorbtpro.base.wrapping.ArrayWrapper`. For example, if `group_select` is enabled, indexing will be performed on groups, otherwise on single columns. ## Caching `MappedArray` supports caching. If a method or a property requires heavy computation, it's wrapped with `vectorbtpro.utils.decorators.cached_method` and `vectorbtpro.utils.decorators.cached_property` respectively. Caching can be disabled globally in `vectorbtpro._settings.caching`. !!! note Because of caching, the class is meant to be immutable and all properties are read-only. To change any attribute, use the `MappedArray.replace` method and pass changes as keyword arguments. ## Saving and loading Like any other class subclassing `vectorbtpro.utils.pickling.Pickleable`, we can save a `MappedArray` instance to the disk with `MappedArray.save` and load it with `MappedArray.load`. ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `MappedArray.metrics`. Metrics for mapped arrays are similar to those for `vectorbtpro.generic.accessors.GenericAccessor`. ```pycon >>> ma.stats(column='a') Start x End z Period 3 days 00:00:00 Count 3 Mean 11.0 Std 1.0 Min 10.0 Median 11.0 Max 12.0 Min Index x Max Index z Name: a, dtype: object ``` The main difference unfolds once the mapped array has a mapping: values are then considered categorical and the usual statistics are meaningless to compute.
For this case, `MappedArray.stats` returns the value counts: ```pycon >>> mapping = {v: "test_" + str(v) for v in np.unique(ma.values)} >>> ma.stats(column='a', settings=dict(mapping=mapping)) Start x End z Period 3 days 00:00:00 Count 3 Value Counts: test_10.0 1 Value Counts: test_11.0 1 Value Counts: test_12.0 1 Value Counts: test_13.0 0 Value Counts: test_14.0 0 Value Counts: test_15.0 0 Value Counts: test_16.0 0 Value Counts: test_17.0 0 Value Counts: test_18.0 0 Name: a, dtype: object ``` `MappedArray.stats` also supports (re-)grouping: ```pycon >>> grouped_ma.stats(column='first') Start x End z Period 3 days 00:00:00 Count 6 Mean 12.5 Std 1.870829 Min 10.0 Median 12.5 Max 15.0 Min Index x Max Index z Name: first, dtype: object ``` ## Plots We can build histograms and boxplots of `MappedArray` directly: ```pycon >>> ma.boxplot().show() ``` ![](/assets/images/api/mapped_boxplot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/mapped_boxplot.dark.svg#only-dark){: .iimg loading=lazy } To use scatterplots or any other plots that require index, convert to pandas first: ```pycon >>> ma.to_pd().vbt.plot().show() ``` ![](/assets/images/api/mapped_to_pd_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/mapped_to_pd_plot.dark.svg#only-dark){: .iimg loading=lazy } !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `MappedArray.subplots`. `MappedArray` class has a single subplot based on `MappedArray.to_pd` and `vectorbtpro.generic.accessors.GenericAccessor.plot`. 
""" import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base.merging import concat_arrays, column_stack_arrays from vectorbtpro.base.resampling.base import Resampler from vectorbtpro.base.reshaping import to_1d_array, to_dict, index_to_series, index_to_frame from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic import nb as generic_nb from vectorbtpro.generic.analyzable import Analyzable from vectorbtpro.records import nb from vectorbtpro.records.col_mapper import ColumnMapper from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.utils import checks from vectorbtpro.utils import chunking as ch from vectorbtpro.utils.array_ import index_repeating_rows_nb from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig from vectorbtpro.utils.decorators import hybrid_method, cached_method from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods from vectorbtpro.utils.mapping import to_value_mapping, apply_mapping from vectorbtpro.utils.warnings_ import warn __all__ = [ "MappedArray", ] MappedArrayT = tp.TypeVar("MappedArrayT", bound="MappedArray") def combine_mapped_with_other( self: MappedArrayT, other: tp.Union["MappedArray", tp.ArrayLike], np_func: tp.Callable[[tp.ArrayLike, tp.ArrayLike], tp.Array1d], ) -> MappedArrayT: """Combine `MappedArray` with other compatible object. 
If other object is also `MappedArray`, their `id_arr` and `col_arr` must match.""" if isinstance(other, MappedArray): checks.assert_array_equal(self.id_arr, other.id_arr) checks.assert_array_equal(self.col_arr, other.col_arr) other = other.values return self.replace(mapped_arr=np_func(self.values, other)) @attach_binary_magic_methods(combine_mapped_with_other) @attach_unary_magic_methods(lambda self, np_func: self.replace(mapped_arr=np_func(self.values))) class MappedArray(Analyzable): """Exposes methods for reducing, converting, and plotting arrays mapped by `vectorbtpro.records.base.Records` class. Args: wrapper (ArrayWrapper): Array wrapper. See `vectorbtpro.base.wrapping.ArrayWrapper`. mapped_arr (array_like): A one-dimensional array of mapped record values. col_arr (array_like): A one-dimensional column array. Must be of the same size as `mapped_arr`. id_arr (array_like): A one-dimensional id array. Defaults to simple range. Must be of the same size as `mapped_arr`. idx_arr (array_like): A one-dimensional index array. Optional. Must be of the same size as `mapped_arr`. mapping (namedtuple, dict or callable): Mapping. col_mapper (ColumnMapper): Column mapper if already known. !!! note It depends upon `wrapper` and `col_arr`, so make sure to invalidate `col_mapper` upon creating a `MappedArray` instance with a modified `wrapper` or `col_arr. `MappedArray.replace` does it automatically. **kwargs: Custom keyword arguments passed to the config. Useful if any subclass wants to extend the config. """ @hybrid_method def row_stack( cls_or_self: tp.MaybeType[MappedArrayT], *objs: tp.MaybeTuple[MappedArrayT], wrapper_kwargs: tp.KwargsLike = None, **kwargs, ) -> MappedArrayT: """Stack multiple `MappedArray` instances along rows. Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers. !!! 
note Will produce a column-sorted array.""" if not isinstance(cls_or_self, type): objs = (cls_or_self, *objs) cls = type(cls_or_self) else: cls = cls_or_self if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, MappedArray): raise TypeError("Each object to be merged must be an instance of MappedArray") if "wrapper" not in kwargs: if wrapper_kwargs is None: wrapper_kwargs = {} kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs) if "col_mapper" not in kwargs: kwargs["col_mapper"] = ColumnMapper.row_stack( *[obj.col_mapper for obj in objs], wrapper=kwargs["wrapper"], ) if "mapped_arr" not in kwargs: mapped_arrs = [] for col in range(kwargs["wrapper"].shape_2d[1]): for obj in objs: col_idxs, col_lens = obj.col_mapper.col_map if len(col_idxs) > 0: if col > 0 and obj.wrapper.shape_2d[1] == 1: mapped_arrs.append(obj.mapped_arr[col_idxs]) elif col_lens[col] > 0: col_end_idxs = np.cumsum(col_lens) col_start_idxs = col_end_idxs - col_lens mapped_arrs.append(obj.mapped_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]) kwargs["mapped_arr"] = concat_arrays(mapped_arrs) if "col_arr" not in kwargs: kwargs["col_arr"] = kwargs["col_mapper"].col_arr if "idx_arr" not in kwargs: stack_idx_arrs = True for obj in objs: if obj.idx_arr is None: stack_idx_arrs = False break if stack_idx_arrs: idx_arrs = [] for col in range(kwargs["wrapper"].shape_2d[1]): n_rows_sum = 0 for obj in objs: col_idxs, col_lens = obj.col_mapper.col_map if len(col_idxs) > 0: if col > 0 and obj.wrapper.shape_2d[1] == 1: idx_arrs.append(obj.idx_arr[col_idxs] + n_rows_sum) elif col_lens[col] > 0: col_end_idxs = np.cumsum(col_lens) col_start_idxs = col_end_idxs - col_lens col_idx_arr = obj.idx_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]] idx_arrs.append(col_idx_arr + n_rows_sum) n_rows_sum += obj.wrapper.shape_2d[0] kwargs["idx_arr"] = concat_arrays(idx_arrs) if "id_arr" not in kwargs: id_arrs = [] for col in 
@hybrid_method
def column_stack(
    cls_or_self: tp.MaybeType[MappedArrayT],
    *objs: tp.MaybeTuple[MappedArrayT],
    wrapper_kwargs: tp.KwargsLike = None,
    get_indexer_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> MappedArrayT:
    """Stack multiple `MappedArray` instances along columns.

    Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers and
    `vectorbtpro.records.col_mapper.ColumnMapper.column_stack` to stack the column mappers.
    Any of `wrapper`, `col_mapper`, `mapped_arr`, `col_arr`, `idx_arr`, and `id_arr`
    that is already present in `kwargs` is used as-is instead of being stacked.
    Index arrays are stacked only if every object has one.

    `get_indexer_kwargs` are passed to
    [pandas.Index.get_indexer](https://pandas.pydata.org/docs/reference/api/pandas.Index.get_indexer.html)
    to translate old indices to new ones after the reindexing operation.

    !!! note
        Will produce a column-sorted array."""
    if isinstance(cls_or_self, type):
        cls = cls_or_self
    else:
        # Called on an instance: the instance itself is the first object to stack
        cls = type(cls_or_self)
        objs = (cls_or_self, *objs)
    if len(objs) == 1:
        # A single argument is expected to be a sequence of objects
        objs = objs[0]
    objs = list(objs)
    if any(not checks.is_instance_of(obj, MappedArray) for obj in objs):
        raise TypeError("Each object to be merged must be an instance of MappedArray")
    if get_indexer_kwargs is None:
        get_indexer_kwargs = {}

    def _iter_populated():
        # Yield each object whose column map is not empty, along with the first item of that map
        for obj in objs:
            col_idxs, _ = obj.col_mapper.col_map
            if len(col_idxs) > 0:
                yield obj, col_idxs

    if "wrapper" not in kwargs:
        if wrapper_kwargs is None:
            wrapper_kwargs = {}
        kwargs["wrapper"] = ArrayWrapper.column_stack(
            *[obj.wrapper for obj in objs],
            **wrapper_kwargs,
        )
    if "col_mapper" not in kwargs:
        kwargs["col_mapper"] = ColumnMapper.column_stack(
            *[obj.col_mapper for obj in objs],
            wrapper=kwargs["wrapper"],
        )
    if "mapped_arr" not in kwargs:
        kwargs["mapped_arr"] = concat_arrays(
            [obj.mapped_arr[col_idxs] for obj, col_idxs in _iter_populated()]
        )
    if "col_arr" not in kwargs:
        kwargs["col_arr"] = kwargs["col_mapper"].col_arr
    if "idx_arr" not in kwargs and all(obj.idx_arr is not None for obj in objs):
        stacked_index = kwargs["wrapper"].index
        stacked_idx_arrs = []
        for obj, col_idxs in _iter_populated():
            obj_idxs = obj.idx_arr[col_idxs]
            if not obj.wrapper.index.equals(stacked_index):
                # Positions refer to the object's own index and must be looked up in the stacked one
                obj_idxs = stacked_index.get_indexer(
                    obj.wrapper.index[obj_idxs],
                    **get_indexer_kwargs,
                )
            stacked_idx_arrs.append(obj_idxs)
        kwargs["idx_arr"] = concat_arrays(stacked_idx_arrs)
    if "id_arr" not in kwargs:
        kwargs["id_arr"] = concat_arrays(
            [obj.id_arr[col_idxs] for obj, col_idxs in _iter_populated()]
        )
    kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
    kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
    return cls(**kwargs)
tp.Optional[tp.ArrayLike] = None, id_arr: tp.Optional[tp.ArrayLike] = None, mapping: tp.Optional[tp.MappingLike] = None, col_mapper: tp.Optional[ColumnMapper] = None, **kwargs, ) -> None: mapped_arr = np.asarray(mapped_arr) col_arr = np.asarray(col_arr) checks.assert_shape_equal(mapped_arr, col_arr, axis=0) if idx_arr is not None: idx_arr = np.asarray(idx_arr) checks.assert_shape_equal(mapped_arr, idx_arr, axis=0) if col_mapper is None: col_mapper = ColumnMapper(wrapper, col_arr) if id_arr is None: id_arr = col_mapper.new_id_arr else: id_arr = np.asarray(id_arr) checks.assert_shape_equal(mapped_arr, id_arr, axis=0) Analyzable.__init__( self, wrapper, mapped_arr=mapped_arr, col_arr=col_arr, idx_arr=idx_arr, id_arr=id_arr, mapping=mapping, col_mapper=col_mapper, **kwargs, ) self._mapped_arr = mapped_arr self._col_arr = col_arr self._idx_arr = idx_arr self._id_arr = id_arr self._mapping = mapping self._col_mapper = col_mapper # Only slices of rows can be selected self._range_only_select = True def replace(self: MappedArrayT, **kwargs) -> MappedArrayT: """See `vectorbtpro.utils.config.Configured.replace`. 
def indexing_func_meta(self, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
    """Perform indexing on `MappedArray` and return metadata.

    If `wrapper_meta` is None, it is computed by `self.wrapper.indexing_func_meta` from
    `*args` and `**kwargs` using this instance's `column_only_select`,
    `range_only_select` and `group_select` settings.

    Elements are first selected by column using `wrapper_meta["col_idxs"]`. If
    `wrapper_meta["rows_changed"]` is true and an index array is available, only elements
    whose row position lies in `[row_idxs.start, row_idxs.stop)` are kept and their row
    positions are shifted so that they are relative to `row_idxs.start`. Without an index
    array, no row filtering takes place.

    Returns a dict with the keys `wrapper_meta`, `new_indices`, `new_mapped_arr`,
    `new_col_arr`, `new_idx_arr` (None if there is no index array) and `new_id_arr`."""
    if wrapper_meta is None:
        wrapper_meta = self.wrapper.indexing_func_meta(
            *args,
            column_only_select=self.column_only_select,
            range_only_select=self.range_only_select,
            group_select=self.group_select,
            **kwargs,
        )
    new_indices, new_col_arr = self.col_mapper.select_cols(wrapper_meta["col_idxs"])
    new_mapped_arr = self.values[new_indices]
    new_idx_arr = self.idx_arr[new_indices] if self.idx_arr is not None else None
    new_id_arr = self.id_arr[new_indices]

    if wrapper_meta["rows_changed"] and new_idx_arr is not None:
        row_idxs = wrapper_meta["row_idxs"]
        mask = (new_idx_arr >= row_idxs.start) & (new_idx_arr < row_idxs.stop)
        new_indices = new_indices[mask]
        new_mapped_arr = new_mapped_arr[mask]
        new_col_arr = new_col_arr[mask]
        # The index array is known to exist in this branch, so no second None check is needed
        new_idx_arr = new_idx_arr[mask] - row_idxs.start
        new_id_arr = new_id_arr[mask]

    return dict(
        wrapper_meta=wrapper_meta,
        new_indices=new_indices,
        new_mapped_arr=new_mapped_arr,
        new_col_arr=new_col_arr,
        new_idx_arr=new_idx_arr,
        new_id_arr=new_id_arr,
    )
def resample(self: MappedArrayT, *args, mapped_meta: tp.DictLike = None, **kwargs) -> MappedArrayT:
    """Perform resampling on `MappedArray`.

    If `mapped_meta` is None, it is computed with `MappedArray.resample_meta` from `*args`
    and `**kwargs`. The new instance is created with `MappedArray.replace` using the new
    wrapper and the new index array stored in the metadata."""
    meta = self.resample_meta(*args, **kwargs) if mapped_meta is None else mapped_meta
    new_wrapper = meta["wrapper_meta"]["new_wrapper"]
    new_idx_arr = meta["new_idx_arr"]
    return self.replace(wrapper=new_wrapper, idx_arr=new_idx_arr)
def sort(
    self: MappedArrayT,
    incl_id: bool = False,
    idx_arr: tp.Optional[tp.Array1d] = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> MappedArrayT:
    """Sort mapped array by column array (primary) and id array (secondary, optional).

    If `idx_arr` is None, `MappedArray.idx_arr` is used. If the array is already sorted
    according to `MappedArray.is_sorted`, no reordering takes place and only `idx_arr`
    and `**kwargs` are passed on.

    When sorting by column only, a stable sort is used such that elements of the same
    column keep their relative order.

    `**kwargs` are passed to `MappedArray.replace`. The result is regrouped using `group_by`."""
    if idx_arr is None:
        idx_arr = self.idx_arr
    if self.is_sorted(incl_id=incl_id):
        return self.replace(idx_arr=idx_arr, **kwargs).regroup(group_by)
    if incl_id:
        ind = np.lexsort((self.id_arr, self.col_arr))  # expensive!
    else:
        # The default argsort kind is not stable and may shuffle elements within a column
        ind = np.argsort(self.col_arr, kind="stable")
    return self.replace(
        mapped_arr=self.values[ind],
        col_arr=self.col_arr[ind],
        id_arr=self.id_arr[ind],
        idx_arr=idx_arr[ind] if idx_arr is not None else None,
        **kwargs,
    ).regroup(group_by)
def to_index(self, minus_one_to_zero: bool = False) -> tp.Index:
    """Convert to index.

    Values are treated as positions in the index of the wrapper.

    If there is no position -1, the index labels at those positions are returned.
    Otherwise, if `minus_one_to_zero` is True, each position -1 becomes position 0.
    If it is False, the label at each position -1 is replaced by a placeholder:
    -1 for an integer index, NaT for a datetime index, and NaN for any other index."""
    index = self.wrapper.index
    if not np.isin(-1, self.values):
        return index[self.values]

    missing = self.values == -1
    positions = self.values.copy()
    # Point missing entries at a valid position so that they can be looked up
    positions[missing] = 0
    if minus_one_to_zero:
        return index[positions]

    if pd.api.types.is_integer_dtype(index):
        placeholder = -1
    elif isinstance(index, pd.DatetimeIndex):
        placeholder = np.datetime64("NaT")
    else:
        placeholder = np.nan
    labels = index.values[positions]
    labels[missing] = placeholder
    return pd.Index(labels, name=index.name)
def reduce_segments(
    self: MappedArrayT,
    segment_arr: tp.Union[str, tp.MaybeTuple[tp.Array1d]],
    reduce_func_nb: tp.Union[str, tp.ReduceFunc],
    *args,
    idx_arr: tp.Optional[tp.Array1d] = None,
    group_by: tp.GroupByLike = None,
    apply_per_group: bool = False,
    dtype: tp.Optional[tp.DTypeLike] = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    **kwargs,
) -> MappedArrayT:
    """Reduce each segment of values in mapped array. Returns a new mapped array.

    `segment_arr` must be an array of integers increasing per column, each indicating a segment.
    It must have the same length as the mapped array. You can also pass a tuple of such arrays.
    In this case, each unique combination of values will be considered a single segment.
    Can also pass the string "idx" to use the index array.

    `reduce_func_nb` can be a string denoting the suffix of a reducing function
    from `vectorbtpro.generic.nb`. For example, "sum" will refer to "sum_reduce_nb".

    !!! warning
        Each segment or combination of segments in `segment_arr` is assumed to be coherent
        and non-repeating. That is, `np.array([0, 1, 0])` for a single column annotates
        three different segments, not two. See `vectorbtpro.utils.array_.index_repeating_rows_nb`.

    !!! hint
        Use `MappedArray.sort` to bring the mapped array to the desired order, if required.

    Applies per group of columns if `apply_per_group` is True.

    Raises `ValueError` if no index array is available or if `segment_arr` is a string
    other than "idx".

    See `vectorbtpro.records.nb.reduce_mapped_segments_nb`.

    `**kwargs` are passed to `MappedArray.replace`."""
    if idx_arr is None:
        idx_arr = self.idx_arr
        if idx_arr is None:
            raise ValueError("Must pass idx_arr")
    col_map = self.col_mapper.get_col_map(group_by=group_by if apply_per_group else False)

    if isinstance(segment_arr, str):
        if segment_arr.lower() != "idx":
            raise ValueError(f"Invalid segment_arr: '{segment_arr}'")
        segment_arr = idx_arr
    if isinstance(segment_arr, tuple):
        # Collapse multiple segment arrays into one array of combined segment ids
        segment_arr = index_repeating_rows_nb(column_stack_arrays(segment_arr))
    if isinstance(reduce_func_nb, str):
        reduce_func_nb = getattr(generic_nb, f"{reduce_func_nb}_reduce_nb")

    func = jit_reg.resolve_option(nb.reduce_mapped_segments_nb, jitted)
    func = ch_reg.resolve_option(func, chunked)
    reduced = func(
        self.values,
        idx_arr,
        self.id_arr,
        col_map,
        segment_arr,
        reduce_func_nb,
        *args,
    )
    new_mapped_arr, new_col_arr, new_idx_arr, new_id_arr = reduced
    return self.replace(
        mapped_arr=np.asarray(new_mapped_arr, dtype=dtype),
        col_arr=new_col_arr,
        idx_arr=new_idx_arr,
        id_arr=new_id_arr,
        **kwargs,
    ).regroup(group_by)
None, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeriesFrame: """Reduce mapped array by column/group. Set `returns_array` to True if `reduce_func_nb` returns an array. Set `returns_idx` to True if `reduce_func_nb` returns row index/position. Must pass `idx_arr`. Set `to_index` to True to return labels instead of positions. Use `fill_value` to set the default value. For implementation details, see * `vectorbtpro.records.nb.reduce_mapped_nb` if `returns_array` is False and `returns_idx` is False * `vectorbtpro.records.nb.reduce_mapped_to_idx_nb` if `returns_array` is False and `returns_idx` is True * `vectorbtpro.records.nb.reduce_mapped_to_array_nb` if `returns_array` is True and `returns_idx` is False * `vectorbtpro.records.nb.reduce_mapped_to_idx_array_nb` if `returns_array` is True and `returns_idx` is True For implementation details on the meta versions, see * `vectorbtpro.records.nb.reduce_mapped_meta_nb` if `returns_array` is False and `returns_idx` is False * `vectorbtpro.records.nb.reduce_mapped_to_idx_meta_nb` if `returns_array` is False and `returns_idx` is True * `vectorbtpro.records.nb.reduce_mapped_to_array_meta_nb` if `returns_array` is True and `returns_idx` is False * `vectorbtpro.records.nb.reduce_mapped_to_idx_array_meta_nb` if `returns_array` is True and `returns_idx` is True """ if isinstance(cls_or_self, type): checks.assert_not_none(col_mapper, arg_name="col_mapper") col_map = col_mapper.get_col_map(group_by=group_by) if not returns_array: if not returns_idx: func = jit_reg.resolve_option(nb.reduce_mapped_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(col_map, fill_value, reduce_func_nb, *args) else: checks.assert_not_none(idx_arr, arg_name="idx_arr") func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(col_map, idx_arr, fill_value, reduce_func_nb, *args) else: if not returns_idx: func = 
jit_reg.resolve_option(nb.reduce_mapped_to_array_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(col_map, fill_value, reduce_func_nb, *args) else: checks.assert_not_none(idx_arr, arg_name="idx_arr") func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_array_meta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(col_map, idx_arr, fill_value, reduce_func_nb, *args) wrapper = col_mapper.wrapper else: if idx_arr is None: if cls_or_self.idx_arr is None: if returns_idx: raise ValueError("Must pass idx_arr") idx_arr = cls_or_self.idx_arr col_map = cls_or_self.col_mapper.get_col_map(group_by=group_by) if not returns_array: if not returns_idx: func = jit_reg.resolve_option(nb.reduce_mapped_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.values, col_map, fill_value, reduce_func_nb, *args) else: checks.assert_not_none(idx_arr, arg_name="idx_arr") func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.values, col_map, idx_arr, fill_value, reduce_func_nb, *args) else: if not returns_idx: func = jit_reg.resolve_option(nb.reduce_mapped_to_array_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.values, col_map, fill_value, reduce_func_nb, *args) else: checks.assert_not_none(idx_arr, arg_name="idx_arr") func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_array_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func(cls_or_self.values, col_map, idx_arr, fill_value, reduce_func_nb, *args) wrapper = cls_or_self.wrapper wrap_kwargs = merge_dicts( dict( name_or_index="reduce" if not returns_array else None, to_index=returns_idx and to_index, fillna=-1 if returns_idx else None, dtype=int_ if returns_idx else None, ), wrap_kwargs, ) return wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs) @cached_method def nth( self, n: int, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, 
@cached_method
def nth_index(
    self,
    n: int,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeSeries:
    """Return index of n-th element of each column/group.

    Delegates to `MappedArray.reduce` with `generic_nb.nth_index_reduce_nb`, passing `n`
    as the only reducer argument and requesting an index to be returned.
    `**kwargs` are passed to `MappedArray.reduce`."""
    wrap_kwargs = merge_dicts({"name_or_index": "nth_index"}, wrap_kwargs)
    # A single `None` taker for the reducer argument `n`
    # NOTE(review): presumably this keeps `n` from being split across chunks - confirm
    args_taker = ch.ArgsTaker(None)
    chunked = ch.specialize_chunked_option(chunked, arg_take_spec={"args": args_taker})
    reduce_func_nb = jit_reg.resolve_option(generic_nb.nth_index_reduce_nb, jitted)
    reduce_kwargs = dict(
        returns_array=False,
        returns_idx=True,
        group_by=group_by,
        jitted=jitted,
        chunked=chunked,
        wrap_kwargs=wrap_kwargs,
    )
    return self.reduce(reduce_func_nb, n, **reduce_kwargs, **kwargs)
@cached_method
def std(
    self,
    ddof: int = 1,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeSeries:
    """Return std by column/group.

    Delegates to `MappedArray.reduce` with `generic_nb.std_reduce_nb`, passing `ddof`
    as the only reducer argument. `**kwargs` are passed to `MappedArray.reduce`."""
    wrap_kwargs = merge_dicts({"name_or_index": "std"}, wrap_kwargs)
    # A single `None` taker for the reducer argument `ddof`
    # NOTE(review): presumably this keeps `ddof` from being split across chunks - confirm
    args_taker = ch.ArgsTaker(None)
    chunked = ch.specialize_chunked_option(chunked, arg_take_spec={"args": args_taker})
    reduce_func_nb = jit_reg.resolve_option(generic_nb.std_reduce_nb, jitted)
    reduce_kwargs = dict(
        returns_array=False,
        returns_idx=False,
        group_by=group_by,
        jitted=jitted,
        chunked=chunked,
        wrap_kwargs=wrap_kwargs,
    )
    return self.reduce(reduce_func_nb, ddof, **reduce_kwargs, **kwargs)
@cached_method
def idxmin(
    self,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeSeries:
    """Return index of min by column/group.

    Delegates to `MappedArray.reduce` with `generic_nb.argmin_reduce_nb` and requests
    an index to be returned. `**kwargs` are passed to `MappedArray.reduce`."""
    wrap_kwargs = merge_dicts({"name_or_index": "idxmin"}, wrap_kwargs)
    reduce_func_nb = jit_reg.resolve_option(generic_nb.argmin_reduce_nb, jitted)
    reduce_kwargs = dict(
        returns_array=False,
        returns_idx=True,
        group_by=group_by,
        jitted=jitted,
        chunked=chunked,
        wrap_kwargs=wrap_kwargs,
    )
    return self.reduce(reduce_func_nb, **reduce_kwargs, **kwargs)
@cached_method
def value_counts(
    self,
    axis: int = 1,
    idx_arr: tp.Optional[tp.Array1d] = None,
    normalize: bool = False,
    sort_uniques: bool = True,
    sort: bool = False,
    ascending: bool = False,
    dropna: bool = False,
    group_by: tp.GroupByLike = None,
    mapping: tp.Union[None, bool, tp.MappingLike] = None,
    incl_all_keys: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """See `vectorbtpro.generic.accessors.GenericAccessor.value_counts`.

    Values are factorized with `pd.factorize` and counted depending on `axis`:

    * 0: per row, which requires an index array (`idx_arr` or `MappedArray.idx_arr`)
    * 1: per column/group
    * -1: across the whole array

    If `incl_all_keys` is True and a mapping is available, keys of the mapping that do not
    occur among the values are appended with a count of zero.

    Set `dropna` to True to drop null values, `sort_uniques` to order the result by value,
    `sort` to order it by count (in the direction given by `ascending`), and `normalize`
    to divide the counts by their total.

    If a mapping is available, it is applied to the index of the result.
    `**kwargs` are passed to `apply_mapping`."""
    checks.assert_in(axis, (-1, 0, 1))
    per_row_or_col = axis == 0 or axis == 1

    mapping = self.resolve_mapping(mapping)
    mapped_codes, mapped_uniques = pd.factorize(self.values, sort=False, use_na_sentinel=False)
    if axis == 0:
        if idx_arr is None:
            idx_arr = self.idx_arr
        checks.assert_not_none(idx_arr, arg_name="idx_arr")
        func = jit_reg.resolve_option(nb.mapped_value_counts_per_row_nb, jitted)
        value_counts = func(mapped_codes, len(mapped_uniques), idx_arr, self.wrapper.shape[0])
    elif axis == 1:
        col_map = self.col_mapper.get_col_map(group_by=group_by)
        func = jit_reg.resolve_option(nb.mapped_value_counts_per_col_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        value_counts = func(mapped_codes, len(mapped_uniques), col_map)
    else:
        func = jit_reg.resolve_option(nb.mapped_value_counts_nb, jitted)
        value_counts = func(mapped_codes, len(mapped_uniques))

    if incl_all_keys and mapping is not None:
        # Whether the values contain a null does not change while iterating over keys
        uniques_have_null = pd.isnull(mapped_uniques).any()
        missing_keys = []
        for x in mapping:
            if pd.isnull(x) and uniques_have_null:
                continue
            if x not in mapped_uniques:
                missing_keys.append(x)
        if per_row_or_col:
            value_counts = np.vstack((value_counts, np.full((len(missing_keys), value_counts.shape[1]), 0)))
        else:
            value_counts = concat_arrays((value_counts, np.full(len(missing_keys), 0)))
        mapped_uniques = concat_arrays((mapped_uniques, np.array(missing_keys)))

    if dropna:
        # `pd.isnull` also accepts object/string values, on which `np.isnan` raises a TypeError.
        # The mask is built only when needed so that such values do not fail without `dropna`.
        nan_mask = pd.isnull(mapped_uniques)
        value_counts = value_counts[~nan_mask]
        mapped_uniques = mapped_uniques[~nan_mask]
    if sort_uniques:
        new_indices = mapped_uniques.argsort()
        value_counts = value_counts[new_indices]
        mapped_uniques = mapped_uniques[new_indices]

    # Total count of each unique value, used for normalizing and sorting
    if per_row_or_col:
        value_counts_sum = value_counts.sum(axis=1)
    else:
        value_counts_sum = value_counts
    if normalize:
        value_counts = value_counts / value_counts_sum.sum()
    if sort:
        if ascending:
            new_indices = value_counts_sum.argsort()
        else:
            new_indices = (-value_counts_sum).argsort()
        value_counts = value_counts[new_indices]
        mapped_uniques = mapped_uniques[new_indices]

    if axis == 0:
        wrapper = ArrayWrapper.from_obj(value_counts)
        value_counts_pd = wrapper.wrap(
            value_counts,
            index=mapped_uniques,
            columns=self.wrapper.index,
            **resolve_dict(wrap_kwargs),
        )
    elif axis == 1:
        value_counts_pd = self.wrapper.wrap(
            value_counts,
            index=mapped_uniques,
            group_by=group_by,
            **resolve_dict(wrap_kwargs),
        )
    else:
        wrapper = ArrayWrapper.from_obj(value_counts)
        value_counts_pd = wrapper.wrap(
            value_counts,
            index=mapped_uniques,
            **merge_dicts(dict(columns=["value_counts"]), wrap_kwargs),
        )
    if mapping is not None:
        value_counts_pd.index = apply_mapping(value_counts_pd.index, mapping, **kwargs)
    return value_counts_pd
tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, silence_warnings: bool = False, ) -> tp.SeriesFrame: """Unstack mapped array to a Series/DataFrame. If `reduce_func_nb` is not None, will use it to reduce conflicting index segments using `MappedArray.reduce_segments`. * If `ignore_index`, will ignore the index and place values on top of each other in every column/group. See `vectorbtpro.records.nb.ignore_unstack_mapped_nb`. * If `repeat_index`, will repeat any index pointed from multiple values. Otherwise, in case of positional conflicts, will throw a warning and use the latest value. See `vectorbtpro.records.nb.repeat_unstack_mapped_nb`. * Otherwise, see `vectorbtpro.records.nb.unstack_mapped_nb`. !!! note Will raise an error if there are multiple values pointing to the same position. Set `ignore_index` to True in this case. !!! warning Mapped arrays represent information in the most memory-friendly format. Mapping back to pandas may occupy lots of memory if records are sparse.""" if ignore_index: if self.wrapper.ndim == 1: return self.wrapper.wrap( self.values, index=np.arange(len(self.values)), group_by=group_by, **resolve_dict(wrap_kwargs), ) col_map = self.col_mapper.get_col_map(group_by=group_by) func = jit_reg.resolve_option(nb.ignore_unstack_mapped_nb, jitted) out = func(self.values, col_map, fill_value) mapping = self.resolve_mapping(mapping) out = apply_mapping(out, mapping, **resolve_dict(mapping_kwargs)) return self.wrapper.wrap(out, index=np.arange(out.shape[0]), group_by=group_by, **resolve_dict(wrap_kwargs)) if idx_arr is None: if self.idx_arr is None: raise ValueError("Must pass idx_arr") idx_arr = self.idx_arr has_conflicts = self.has_conflicts(idx_arr=idx_arr, group_by=group_by) if has_conflicts and repeat_index: col_arr = self.col_mapper.get_col_arr(group_by=group_by) target_shape = self.wrapper.get_shape_2d(group_by=group_by) func = jit_reg.resolve_option(nb.mapped_coverage_map_nb, jitted) coverage_map 
def get_pd_mask(
    self,
    idx_arr: tp.Optional[tp.Array1d] = None,
    group_by: tp.GroupByLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """Build a boolean Series/DataFrame that marks the positions referenced by the mapped values.

    Each (row, column) pair formed by `idx_arr` and the column array of the column mapper
    is set to True; every other cell stays False.

    Args:
        idx_arr: Row index of each mapped value. Falls back to `self.idx_arr` if None.
        group_by: Forwarded to the column mapper and to the wrapper.
        wrap_kwargs: Keyword arguments forwarded to `self.wrapper.wrap`.

    Raises:
        ValueError: If neither the argument nor the instance provides an index array."""
    row_idxs = idx_arr
    if row_idxs is None:
        row_idxs = self.idx_arr
        if row_idxs is None:
            raise ValueError("Must pass idx_arr")

    col_idxs = self.col_mapper.get_col_arr(group_by=group_by)
    shape_2d = self.wrapper.get_shape_2d(group_by=group_by)

    mask = np.zeros(shape_2d, dtype=bool)
    mask[row_idxs, col_idxs] = True

    extra_kwargs = resolve_dict(wrap_kwargs)
    return self.wrapper.wrap(mask, group_by=group_by, **extra_kwargs)
def histplot(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.BaseFigure:
    """Draw a histogram of the mapped values for each column/group.

    The values are first stacked with `to_pd(..., ignore_index=True)`; `**kwargs` are
    forwarded to the `vbt.histplot` accessor method of the stacked result."""
    stacked = self.to_pd(group_by=group_by, ignore_index=True)
    return stacked.vbt.histplot(**kwargs)
@register_jitted(cache=True)
def generate_ids_nb(col_arr: tp.Array1d, n_cols: int) -> tp.Array1d:
    """Assign each element a running id that counts up separately within its column.

    The first element seen for a column gets id 0, the next one id 1, and so on,
    in the order the elements appear in `col_arr`."""
    next_id = np.zeros(n_cols, dtype=int_)
    ids = np.empty_like(col_arr)
    for i, col in enumerate(col_arr):
        ids[i] = next_id[col]
        next_id[col] += 1
    return ids
@register_jitted(cache=True)
def record_col_lens_select_nb(
    records: tp.RecordArray,
    col_lens: tp.GroupLens,
    new_cols: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
    """Select columns from column-sorted records using column lengths.

    For every entry of `new_cols`, the contiguous run of records of that column is copied
    and its `col` field is rewritten to the position of the entry within `new_cols`.

    Returns the positions of the selected records in `records` and the new records."""
    stops = np.cumsum(col_lens)
    starts = stops - col_lens
    total = np.sum(col_lens[new_cols])
    indices_out = np.empty(total, dtype=int_)
    records_arr_out = np.empty(total, dtype=records.dtype)

    offset = 0
    for new_col in range(new_cols.shape[0]):
        old_col = new_cols[new_col]
        start = starts[old_col]
        stop = stops[old_col]
        if start == stop:
            # Nothing recorded for this column
            continue
        count = stop - start
        chunk = np.copy(records[start:stop])
        # The copied records must point to their new column position
        chunk["col"][:] = new_col
        indices_out[offset : offset + count] = np.arange(start, stop)
        records_arr_out[offset : offset + count] = chunk
        offset += count
    return indices_out, records_arr_out
@register_jitted(cache=True)
def record_col_map_select_nb(
    records: tp.RecordArray,
    col_map: tp.GroupMap,
    new_cols: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
    """Select columns from records using the column map `col_map`.

    Counterpart of `record_col_lens_select_nb` that looks up the records of each column
    through the index segments stored in the map.

    Returns the positions of the selected records in `records` and the new records,
    whose `col` field holds the position of the column within `new_cols`."""
    col_idxs = col_map[0]
    col_lens = col_map[1]
    starts = np.cumsum(col_lens) - col_lens
    total = np.sum(col_lens[new_cols])
    indices_out = np.empty(total, dtype=int_)
    records_arr_out = np.empty(total, dtype=records.dtype)

    offset = 0
    for new_pos, old_col in enumerate(new_cols):
        count = col_lens[old_col]
        if count == 0:
            continue
        start = starts[old_col]
        segment = col_idxs[start : start + count]
        chunk = np.copy(records[segment])
        chunk["col"][:] = new_pos
        stop = offset + count
        indices_out[offset:stop] = segment
        records_arr_out[offset:stop] = chunk
        offset = stop
    return indices_out, records_arr_out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        col_map=base_ch.GroupMapSlicer(),
        n=None,
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def last_n_nb(col_map: tp.GroupMap, n: int) -> tp.Array1d:
    """Returns the mask of the last N elements.

    For each column, the last `n` indices of its segment in the column map are set to True.
    If `n` exceeds the column length, the whole column is selected.
    If `n` is zero (or negative), nothing is selected."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full(col_idxs.shape[0], False, dtype=np.bool_)

    for col in prange(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        idxs = col_idxs[col_start_idx : col_start_idx + col_len]
        # Use an explicit start position: `idxs[-n:]` equals `idxs[0:]` for n == 0
        # and would mark the entire column instead of no element
        out[idxs[max(col_len - n, 0) :]] = True
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        n=None,
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def top_n_mapped_nb(mapped_arr: tp.Array1d, col_map: tp.GroupMap, n: int) -> tp.Array1d:
    """Returns the mask of the top N mapped elements.

    For each column, the values are ranked with `np.argsort` and the `n` elements that
    come last in that order are set to True. If `n` exceeds the column length,
    the whole column is selected. If `n` is zero (or negative), nothing is selected."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full(mapped_arr.shape[0], False, dtype=np.bool_)

    for col in prange(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        idxs = col_idxs[col_start_idx : col_start_idx + col_len]
        order = np.argsort(mapped_arr[idxs])
        # Use an explicit start position: `order[-n:]` equals `order[0:]` for n == 0
        # and would mark the entire column instead of no element
        out[idxs[order[max(col_len - n, 0) :]]] = True
    return out
@register_chunkable(
    size=ch.ArgSizer(arg_query="n_values"),
    arg_take_spec=dict(n_values=ch.CountAdapter(), map_func_nb=None, args=ch.ArgsTaker()),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def map_records_meta_nb(n_values: int, map_func_nb: tp.MappedReduceMetaFunc, *args) -> tp.Array1d:
    """Meta version of `map_records_nb`.

    Instead of a record, `map_func_nb` receives the record index followed by `*args`
    and must return a single value. The results are collected into a float array
    of length `n_values`."""
    mapped = np.empty(n_values, dtype=float_)
    for record_idx in prange(n_values):
        value = map_func_nb(record_idx, *args)
        mapped[record_idx] = value
    return mapped
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        n_values=ch.CountAdapter(mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        apply_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_meta_nb(n_values: int, col_map: tp.GroupMap, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array1d:
    """Meta version of `apply_nb`.

    `apply_func_nb` is called once per non-empty column with the indices of that column,
    the column index, and `*args`. It must return an array, which is written
    to the positions given by those indices in the float output of length `n_values`."""
    col_idxs = col_map[0]
    col_lens = col_map[1]
    starts = np.cumsum(col_lens) - col_lens
    result = np.empty(n_values, dtype=float_)

    n_cols = col_lens.shape[0]
    for col in prange(n_cols):
        count = col_lens[col]
        if count != 0:
            start = starts[col]
            segment = col_idxs[start : start + count]
            result[segment] = apply_func_nb(segment, col, *args)
    return result
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        fill_value=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_nb(
    mapped_arr: tp.Array1d,
    col_map: tp.GroupMap,
    fill_value: float,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Reduce the mapped values of each column to a single value.

    `reduce_func_nb` receives the mapped values of one column followed by `*args`
    and must return a single value. Columns without any values keep `fill_value`.

    Works directly on the mapped array, so it is faster and needs less memory than
    unstacking with `unstack_mapped_nb` and reducing with `vbt.*` afterwards,
    but it does not take advantage of caching."""
    col_idxs = col_map[0]
    col_lens = col_map[1]
    starts = np.cumsum(col_lens) - col_lens
    n_cols = col_lens.shape[0]
    reduced = np.full(n_cols, fill_value, dtype=float_)

    for col in prange(n_cols):
        count = col_lens[col]
        if count != 0:
            start = starts[col]
            segment = col_idxs[start : start + count]
            reduced[col] = reduce_func_nb(mapped_arr[segment], *args)
    return reduced
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        col_map=base_ch.GroupMapSlicer(),
        idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        fill_value=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_nb(
    mapped_arr: tp.Array1d,
    col_map: tp.GroupMap,
    idx_arr: tp.Array1d,
    fill_value: float,
    reduce_func_nb: tp.ReduceFunc,
    *args,
) -> tp.Array1d:
    """Reduce mapped array by column to an index.

    Same as `reduce_mapped_nb` except `idx_arr` must be passed: the value returned by
    `reduce_func_nb` is used as a position within the values of the column and is
    translated into the corresponding element of `idx_arr`. Columns without any values
    keep `fill_value`.

    !!! note
        `reduce_func_nb` must return integers or raise an exception."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    out = np.full(col_lens.shape[0], fill_value, dtype=float_)

    for col in prange(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        idxs = col_idxs[col_start_idx : col_start_idx + col_len]
        col_out = reduce_func_nb(mapped_arr[idxs], *args)
        # Look up the single element directly; `idx_arr[idxs][col_out]` would first
        # materialize a copy of all indices of the column just to read one of them
        out[col] = idx_arr[idxs[col_out]]
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(col_map=base_ch.GroupMapSlicer(), fill_value=None, reduce_func_nb=None, args=ch.ArgsTaker()),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_array_meta_nb(
    col_map: tp.GroupMap,
    fill_value: float,
    reduce_func_nb: tp.MappedReduceToArrayMetaFunc,
    *args,
) -> tp.Array2d:
    """Meta version of `reduce_mapped_to_array_nb`.

    `reduce_func_nb` takes the same arguments as in `reduce_mapped_meta_nb` (the mapped indices,
    the column index, and `*args`) but its result is read element by element as an array.

    Returns a float array with one row per element returned for the first non-empty column
    and one column per entry in the column lengths. Empty columns keep `fill_value`.

    !!! note
        `col0` and `midxs0` are assigned only if at least one column has values."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    # Find the first non-empty column: its result determines the number of output rows
    for col in range(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len > 0:
            col_start_idx = col_start_idxs[col]
            col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
            break

    col_0_out = reduce_func_nb(midxs0, col0, *args)
    out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
    for i in range(col_0_out.shape[0]):
        out[i, col0] = col_0_out[i]

    # Columns before col0 are empty, so only the columns after it remain to be reduced
    for col in prange(col0 + 1, col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        idxs = col_idxs[col_start_idx : col_start_idx + col_len]
        col_out = reduce_func_nb(idxs, col, *args)
        # NOTE(review): assumes col_out has no more elements than col_0_out - confirm with callers
        for i in range(col_out.shape[0]):
            out[i, col] = col_out[i]
    return out
@register_chunkable(
    size=base_ch.GroupLensSizer(arg_query="col_map"),
    arg_take_spec=dict(
        col_map=base_ch.GroupMapSlicer(),
        idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
        fill_value=None,
        reduce_func_nb=None,
        args=ch.ArgsTaker(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_array_meta_nb(
    col_map: tp.GroupMap,
    idx_arr: tp.Array1d,
    fill_value: float,
    reduce_func_nb: tp.MappedReduceToArrayMetaFunc,
    *args,
) -> tp.Array2d:
    """Meta version of `reduce_mapped_to_idx_array_nb`.

    `reduce_func_nb` takes the same arguments as in `reduce_mapped_meta_nb` (the mapped indices,
    the column index, and `*args`). Each element it returns is used as a position within
    the indices of the column and is translated into the corresponding element of `idx_arr`.

    Returns a float array with one row per element returned for the first non-empty column
    and one column per entry in the column lengths. Empty columns keep `fill_value`.

    !!! note
        `col0` and `midxs0` are assigned only if at least one column has values."""
    col_idxs, col_lens = col_map
    col_start_idxs = np.cumsum(col_lens) - col_lens
    # Find the first non-empty column: its result determines the number of output rows
    for col in range(col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len > 0:
            col_start_idx = col_start_idxs[col]
            col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
            break

    col_0_out = reduce_func_nb(midxs0, col0, *args)
    out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
    for i in range(col_0_out.shape[0]):
        # Position within the column -> position in the mapped array -> element of idx_arr
        out[i, col0] = idx_arr[midxs0[col_0_out[i]]]

    # Columns before col0 are empty, so only the columns after it remain to be reduced
    for col in prange(col0 + 1, col_lens.shape[0]):
        col_len = col_lens[col]
        if col_len == 0:
            continue
        col_start_idx = col_start_idxs[col]
        idxs = col_idxs[col_start_idx : col_start_idx + col_len]
        col_out = reduce_func_nb(idxs, col, *args)
        # The row count is taken from the first non-empty column, not from col_out
        # NOTE(review): assumes col_out has at least as many elements as col_0_out - confirm with callers
        for i in range(col_0_out.shape[0]):
            out[i, col] = idx_arr[idxs[col_out[i]]]
    return out
@register_jitted(cache=True)
def mapped_has_conflicts_nb(col_arr: tp.Array1d, idx_arr: tp.Array1d, target_shape: tp.Shape) -> bool:
    """Check whether mapped array has positional conflicts.

    A conflict exists if two elements share the same pair of row (`idx_arr`) and
    column (`col_arr`). Returns True as soon as the first such pair is found."""
    # A boolean mask is sufficient to remember visited positions and takes
    # far less memory than the default float array of the same shape
    seen = np.zeros(target_shape, dtype=np.bool_)

    for i in range(len(col_arr)):
        row = idx_arr[i]
        col = col_arr[i]
        if seen[row, col]:
            return True
        seen[row, col] = True
    return False
def _unstack_mapped_nb(
    mapped_arr,
    col_arr,
    idx_arr,
    target_shape,
    fill_value,
):
    # Dual-purpose function: called with a NumPy array it computes the result
    # right away; called with anything else (i.e., types during overload
    # resolution) it only returns the implementation to be compiled.
    called_with_numpy = isinstance(mapped_arr, np.ndarray)
    if called_with_numpy:
        values_dtype = mapped_arr.dtype
        filler_dtype = np.array(fill_value).dtype
    else:
        values_dtype = as_dtype(mapped_arr.dtype)
        filler_dtype = as_dtype(fill_value)
    # The output must be able to hold both the mapped values and the fill value
    out_dtype = np.promote_types(values_dtype, filler_dtype)

    def impl(mapped_arr, col_arr, idx_arr, target_shape, fill_value):
        unstacked = np.full(target_shape, fill_value, dtype=out_dtype)
        n_values = mapped_arr.shape[0]
        for k in range(n_values):
            # Later elements overwrite earlier ones at the same position
            unstacked[idx_arr[k], col_arr[k]] = mapped_arr[k]
        return unstacked

    if called_with_numpy:
        return impl(mapped_arr, col_arr, idx_arr, target_shape, fill_value)
    return impl
def _repeat_unstack_mapped_nb(
    mapped_arr,
    col_arr,
    idx_arr,
    repeat_cnt_arr,
    n_cols,
    fill_value,
):
    # Dual-purpose function: called with a NumPy array it computes the result
    # right away; called with anything else (i.e., types during overload
    # resolution) it only returns the implementation to be compiled.
    called_with_numpy = isinstance(mapped_arr, np.ndarray)
    if called_with_numpy:
        values_dtype = mapped_arr.dtype
        filler_dtype = np.array(fill_value).dtype
    else:
        values_dtype = as_dtype(mapped_arr.dtype)
        filler_dtype = as_dtype(fill_value)
    # The output must be able to hold both the mapped values and the fill value
    out_dtype = np.promote_types(values_dtype, filler_dtype)

    def impl(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value):
        # First output row reserved for each original index
        first_row_arr = np.cumsum(repeat_cnt_arr) - repeat_cnt_arr
        unstacked = np.full((np.sum(repeat_cnt_arr), n_cols), fill_value, dtype=out_dtype)
        # Number of values already written per (original index, column) pair
        written = np.zeros((len(repeat_cnt_arr), n_cols), dtype=int_)
        for k in range(len(col_arr)):
            row = idx_arr[k]
            col = col_arr[k]
            unstacked[first_row_arr[row] + written[row, col], col] = mapped_arr[k]
            written[row, col] += 1
        return unstacked

    if called_with_numpy:
        return impl(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value)
    return impl
int, fill_value: float, ) -> tp.Array2d: """Unstack mapped array using repeated index data.""" return _repeat_unstack_mapped_nb(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Modules that register objects across vectorbtpro.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.registries.ca_registry import * from vectorbtpro.registries.ch_registry import * from vectorbtpro.registries.jit_registry import * from vectorbtpro.registries.pbar_registry import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Global registry for cacheables. Caching in vectorbt is achieved through a combination of decorators and the registry. Cacheable decorators such as `vectorbtpro.utils.decorators.cacheable` take a function and wrap it with another function that behaves like the wrapped function but also takes care of all caching modalities. 
But unlike other implementations such as that of `functools.lru_cache`, the actual caching procedure doesn't happen nor are the results stored inside the decorators themselves: decorators just register a so-called "setup" for the wrapped function at the registry (see `CARunSetup`). ## Runnable setups The actual magic happens within a runnable setup: it takes the function that should be called and the arguments that should be passed to this function, looks whether the result should be cached, runs the function, stores the result in the cache, updates the metrics, etc. It then returns the resulting object to the wrapping function, which in turn returns it to the user. Each setup is stateful - it stores the cache, the number of hits and misses, and other metadata. Thus, there can be only one registered setup per each cacheable function globally at a time. To avoid creating new setups for the same function over and over again, each setup can be uniquely identified by its function through hashing: ```pycon >>> from vectorbtpro import * >>> my_func = lambda: np.random.uniform(size=1000000) >>> # Decorator returns a wrapper >>> my_ca_func = vbt.cached(my_func) >>> # Wrapper registers a new setup >>> my_ca_func.get_ca_setup() CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={}) >>> # Another call won't register a new setup but return the existing one >>> my_ca_func.get_ca_setup() CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={}) >>> # Only one CARunSetup object per wrapper and optionally the instance the wrapper is bound to >>> hash(my_ca_func.get_ca_setup()) == hash((my_ca_func, None)) True ``` When we call `my_ca_func`, it takes the setup from the registry and calls `CARunSetup.run`. The caching happens by the setup itself and isn't in any way visible to `my_ca_func`. 
To access the cache or any metric of interest, we can ask the setup: ```pycon >>> my_setup = my_ca_func.get_ca_setup() >>> # Cache is empty >>> my_setup.get_stats() { 'hash': 4792160544297109364, 'string': '>', 'use_cache': True, 'whitelist': False, 'caching_enabled': True, 'hits': 0, 'misses': 0, 'total_size': '0 Bytes', 'total_elapsed': None, 'total_saved': None, 'first_run_time': None, 'last_run_time': None, 'first_hit_time': None, 'last_hit_time': None, 'creation_time': 'now', 'last_update_time': None } >>> # The result is cached >>> my_ca_func() >>> my_setup.get_stats() { 'hash': 4792160544297109364, 'string': '>', 'use_cache': True, 'whitelist': False, 'caching_enabled': True, 'hits': 0, 'misses': 1, 'total_size': '8.0 MB', 'total_elapsed': '11.33 milliseconds', 'total_saved': '0 milliseconds', 'first_run_time': 'now', 'last_run_time': 'now', 'first_hit_time': None, 'last_hit_time': None, 'creation_time': 'now', 'last_update_time': None } >>> # The cached result is retrieved >>> my_ca_func() >>> my_setup.get_stats() { 'hash': 4792160544297109364, 'string': '>', 'use_cache': True, 'whitelist': False, 'caching_enabled': True, 'hits': 1, 'misses': 1, 'total_size': '8.0 MB', 'total_elapsed': '11.33 milliseconds', 'total_saved': '11.33 milliseconds', 'first_run_time': 'now', 'last_run_time': 'now', 'first_hit_time': 'now', 'last_hit_time': 'now', 'creation_time': 'now', 'last_update_time': None } ``` ## Enabling/disabling caching To enable or disable caching, we can invoke `CARunSetup.enable_caching` and `CARunSetup.disable_caching` respectively. This will set the `CARunSetup.use_cache` flag to True or False. Even though we expressed our desire to change caching rules, the final decision also depends on the global settings and whether the setup is whitelisted in case caching is disabled globally.
This decision is available via `CARunSetup.caching_enabled`: ```pycon >>> my_setup.disable_caching() >>> my_setup.caching_enabled False >>> my_setup.enable_caching() >>> my_setup.caching_enabled True >>> vbt.settings.caching['disable'] = True >>> my_setup.caching_enabled False >>> my_setup.enable_caching() UserWarning: This operation has no effect: caching is disabled globally and this setup is not whitelisted >>> my_setup.enable_caching(force=True) >>> my_setup.caching_enabled True >>> vbt.settings.caching['disable_whitelist'] = True >>> my_setup.caching_enabled False >>> my_setup.enable_caching(force=True) UserWarning: This operation has no effect: caching and whitelisting are disabled globally ``` To disable registration of new setups completely, use `disable_machinery`: ```pycon >>> vbt.settings.caching['disable_machinery'] = True ``` ## Setup hierarchy But what if we wanted to change caching rules for an entire instance or class at once? Even if we changed the setup of every cacheable function declared in the class, how do we make sure that each future subclass or instance inherits the changes that we applied? To account for this, vectorbt provides us with a set of setups that both are stateful and can delegate various operations to their child setups, all the way down to `CARunSetup`. The setup hierarchy follows the inheritance hierarchy in OOP: ![](/assets/images/api/setup_hierarchy.svg){: loading=lazy style="width:700px;" } For example, calling `B.get_ca_setup().disable_caching()` would disable caching for each current and future subclass and instance of `B`, but it won't disable caching for `A` or any other superclass of `B`. In turn, each instance of `B` would then disable caching for each cacheable property and method in that instance. As we see, the propagation of this operation is happening from top to bottom. 
The reason why unbound setups are stretching outside of their classes in the diagram is because there is no easy way to derive the class when calling a cacheable decorator, thus their functions are considered to be living on their own. When calling `B.f.get_ca_setup().disable_caching()`, we are disabling caching for the function `B.f` for each current and future subclass and instance of `B`, while all other functions remain untouched. But what happens when we enable caching for the class `B` and disable caching for the unbound function `B.f`? Would the future method `b2.f` be cached or not? Quite easy: it would then inherit the state from the setup that has been updated more recently. Here is another illustration of how operations are propagated from parents to children: ![](/assets/images/api/setup_propagation.svg){: loading=lazy style="width:800px;" } The diagram above depicts the following setup hierarchy: ```pycon >>> # Populate setups at init >>> vbt.settings.caching.reset() >>> vbt.settings.caching['register_lazily'] = False >>> class A(vbt.Cacheable): ... @vbt.cached_property ... def f1(self): pass >>> class B(A): ... def f2(self): pass >>> class C(A): ... @vbt.cached_method ... def f2(self): pass >>> b1 = B() >>> c1 = C() >>> c2 = C() >>> print(vbt.prettify(A.get_ca_setup().get_setup_hierarchy())) [ { "parent": "", "children": [ { "parent": "", "children": [ "" ] } ] }, { "parent": "", "children": [ { "parent": "", "children": [ "", "" ] }, { "parent": "", "children": [ "", "" ] } ] } ] >>> print(vbt.prettify(A.f1.get_ca_setup().get_setup_hierarchy())) [ "", "", "" ] >>> print(vbt.prettify(C.f2.get_ca_setup().get_setup_hierarchy())) [ "", "" ] ``` Let's disable caching for the entire `A` class: ```pycon >>> A.get_ca_setup().disable_caching() >>> A.get_ca_setup().use_cache False >>> B.get_ca_setup().use_cache False >>> C.get_ca_setup().use_cache False ``` This disabled caching for `A`, subclasses `B` and `C`, their instances, and any instance function. 
But it didn't touch unbound functions such as `C.f1` and `C.f2`: ```pycon >>> C.f1.get_ca_setup().use_cache True >>> C.f2.get_ca_setup().use_cache True ``` This is because unbound functions are not children of the classes they are declared in! Still, any future instance method of `C` won't be cached because it checks which parent has been updated more recently: the class or the unbound function. In our case, the class had a more recent update. ```pycon >>> c3 = C() >>> C.f2.get_ca_setup(c3).use_cache False ``` In fact, if we want to disable an entire class but leave one function untouched, we need to perform two operations in a particular order: 1) disable caching on the class and 2) enable caching on the unbound function. ```pycon >>> A.get_ca_setup().disable_caching() >>> C.f2.get_ca_setup().enable_caching() >>> c4 = C() >>> C.f2.get_ca_setup(c4).use_cache True ``` ## Getting statistics The main advantage of having a central registry of setups is that we can easily find any setup registered in any part of vectorbt that matches some condition using `CacheableRegistry.match_setups`. !!! note By default, all setups are registered lazily - no setup is registered until it's run or explicitly called. To change this behavior, set `register_lazily` in the global settings to False.
For example, let's look which setups have been registered so far: ```pycon >>> vbt.ca_reg.match_setups(kind=None) { CAClassSetup(registry=, use_cache=None, whitelist=None, cls=), CAClassSetup(registry=, use_cache=None, whitelist=None, cls=), CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=), CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=), CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CAUnboundSetup(registry=, use_cache=True, whitelist=False, cacheable=), CAUnboundSetup(registry=, use_cache=True, whitelist=False, cacheable=) } ``` Let's get the runnable setup of any property and method called `f2`: ```pycon >>> vbt.ca_reg.match_setups('f2', kind='runnable') { CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}), CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}) } ``` But there is a better way to get the stats: `CAQueryDelegator.get_stats`. 
It returns a DataFrame with setup stats as rows: ```pycon >>> vbt.CAQueryDelegator('f2', kind='runnable').get_stats() string use_cache whitelist \\ hash 3506416602224216137 True False -4747092115268118855 True False -4748466030718995055 True False caching_enabled hits misses total_size total_elapsed \\ hash 3506416602224216137 True 0 0 0 Bytes None -4747092115268118855 True 0 0 0 Bytes None -4748466030718995055 True 0 0 0 Bytes None total_saved first_run_time last_run_time first_hit_time \\ hash 3506416602224216137 None None None None -4747092115268118855 None None None None -4748466030718995055 None None None None last_hit_time creation_time last_update_time hash 3506416602224216137 None 9 minutes ago 9 minutes ago -4747092115268118855 None 9 minutes ago 9 minutes ago -4748466030718995055 None 9 minutes ago 9 minutes ago ``` ## Clearing up Instance and runnable setups hold only weak references to their instances such that deleting those instances won't keep them in memory and will automatically remove the setups. 
def is_cacheable_function(cacheable: tp.Any) -> bool:
    """Check if `cacheable` is a cacheable function.

    The object qualifies when it is callable, exposes a falsy `is_method` attribute,
    and exposes an `is_cacheable` attribute, whose value is returned."""
    if not callable(cacheable):
        return False
    if not hasattr(cacheable, "is_method"):
        return False
    if cacheable.is_method:
        # Methods are handled by `is_cacheable_method`
        return False
    if not hasattr(cacheable, "is_cacheable"):
        return False
    return cacheable.is_cacheable
hasattr(cacheable, "is_cacheable") and cacheable.is_cacheable ) def is_bindable_cacheable(cacheable: tp.Any) -> bool: """Check if `cacheable` is a cacheable that can be bound to an instance.""" return is_cacheable_property(cacheable) or is_cacheable_method(cacheable) def is_cacheable(cacheable: tp.Any) -> bool: """Check if `cacheable` is a cacheable.""" return is_cacheable_function(cacheable) or is_bindable_cacheable(cacheable) def get_obj_id(instance: object) -> tp.Tuple[type, int]: """Get id of an instance.""" return type(instance), id(instance) CAQueryT = tp.TypeVar("CAQueryT", bound="CAQuery") InstanceT = tp.Optional[tp.Union[Cacheable, ReferenceType]] def _instance_converter(instance: InstanceT) -> InstanceT: """Make the reference to the instance weak.""" if instance is not None and instance is not _GARBAGE and not isinstance(instance, ReferenceType): return ref(instance) return instance @define class CAQuery(DefineMixin): """Data class that represents a query for matching and ranking setups.""" cacheable: tp.Optional[tp.Union[tp.Callable, cacheableT, str, Regex]] = define.field(default=None) """Cacheable object or its name (case-sensitive).""" instance: InstanceT = define.field(default=None, converter=_instance_converter) """Weak reference to the instance `CAQuery.cacheable` is bound to.""" cls: tp.Optional[tp.TypeLike] = define.field(default=None) """Class of the instance or its name (case-sensitive) `CAQuery.cacheable` is bound to.""" base_cls: tp.Optional[tp.TypeLike] = define.field(default=None) """Base class of the instance or its name (case-sensitive) `CAQuery.cacheable` is bound to.""" options: tp.Optional[dict] = define.field(default=None) """Options to match.""" @classmethod def parse(cls: tp.Type[CAQueryT], query_like: tp.Any, use_base_cls: bool = True) -> CAQueryT: """Parse a query-like object. !!! note Not all attribute combinations can be safely parsed by this function. For example, you cannot combine cacheable together with options. 
Usage: ```pycon >>> vbt.CAQuery.parse(lambda x: x) CAQuery(cacheable= at 0x7fd4766c7730>, instance=None, cls=None, base_cls=None, options=None) >>> vbt.CAQuery.parse("a") CAQuery(cacheable='a', instance=None, cls=None, base_cls=None, options=None) >>> vbt.CAQuery.parse("A.a") CAQuery(cacheable='a', instance=None, cls=None, base_cls='A', options=None) >>> vbt.CAQuery.parse("A") CAQuery(cacheable=None, instance=None, cls=None, base_cls='A', options=None) >>> vbt.CAQuery.parse("A", use_base_cls=False) CAQuery(cacheable=None, instance=None, cls='A', base_cls=None, options=None) >>> vbt.CAQuery.parse(vbt.Regex("[A-B]")) CAQuery(cacheable=None, instance=None, cls=None, base_cls=Regex(pattern='[A-B]', flags=0), options=None) >>> vbt.CAQuery.parse(dict(my_option=100)) CAQuery(cacheable=None, instance=None, cls=None, base_cls=None, options={'my_option': 100}) ``` """ if query_like is None: return CAQuery() if isinstance(query_like, CAQuery): return query_like if isinstance(query_like, CABaseSetup): return query_like.query if isinstance(query_like, cacheable_property): return cls(cacheable=query_like) if isinstance(query_like, str) and query_like[0].islower(): return cls(cacheable=query_like) if isinstance(query_like, str) and query_like[0].isupper() and "." 
in query_like: if use_base_cls: return cls(cacheable=query_like.split(".")[1], base_cls=query_like.split(".")[0]) return cls(cacheable=query_like.split(".")[1], cls=query_like.split(".")[0]) if isinstance(query_like, str) and query_like[0].isupper(): if use_base_cls: return cls(base_cls=query_like) return cls(cls=query_like) if isinstance(query_like, Regex): if use_base_cls: return cls(base_cls=query_like) return cls(cls=query_like) if isinstance(query_like, type): if use_base_cls: return cls(base_cls=query_like) return cls(cls=query_like) if isinstance(query_like, tuple): if use_base_cls: return cls(base_cls=query_like) return cls(cls=query_like) if isinstance(query_like, dict): return cls(options=query_like) if callable(query_like): return cls(cacheable=query_like) return cls(instance=query_like) @property def instance_obj(self) -> tp.Optional[tp.Union[Cacheable, object]]: """Instance object.""" if self.instance is _GARBAGE: return _GARBAGE if self.instance is not None and self.instance() is None: return _GARBAGE return self.instance() if self.instance is not None else None def matches_setup(self, setup: "CABaseSetup") -> bool: """Return whether the setup matches this query. Usage: Let's evaluate various queries: ```pycon >>> class A(vbt.Cacheable): ... @vbt.cached_method(my_option=True) ... def f(self): ... return None >>> class B(A): ... pass >>> @vbt.cached(my_option=False) ... def f(): ... return None >>> a = A() >>> b = B() >>> def match_query(query): ... matched = [] ... if query.matches_setup(A.f.get_ca_setup()): # unbound method ... matched.append('A.f') ... if query.matches_setup(A.get_ca_setup()): # class ... matched.append('A') ... if query.matches_setup(a.get_ca_setup()): # instance ... matched.append('a') ... if query.matches_setup(A.f.get_ca_setup(a)): # instance method ... matched.append('a.f') ... if query.matches_setup(B.f.get_ca_setup()): # unbound method ... matched.append('B.f') ... if query.matches_setup(B.get_ca_setup()): # class ... 
matched.append('B') ... if query.matches_setup(b.get_ca_setup()): # instance ... matched.append('b') ... if query.matches_setup(B.f.get_ca_setup(b)): # instance method ... matched.append('b.f') ... if query.matches_setup(f.get_ca_setup()): # function ... matched.append('f') ... return matched >>> match_query(vbt.CAQuery()) ['A.f', 'A', 'a', 'a.f', 'B.f', 'B', 'b', 'b.f', 'f'] >>> match_query(vbt.CAQuery(cacheable=A.f)) ['A.f', 'a.f', 'B.f', 'b.f'] >>> match_query(vbt.CAQuery(cacheable=B.f)) ['A.f', 'a.f', 'B.f', 'b.f'] >>> match_query(vbt.CAQuery(cls=A)) ['A', 'a', 'a.f'] >>> match_query(vbt.CAQuery(cls=B)) ['B', 'b', 'b.f'] >>> match_query(vbt.CAQuery(cls=vbt.Regex('[A-B]'))) ['A', 'a', 'a.f', 'B', 'b', 'b.f'] >>> match_query(vbt.CAQuery(base_cls=A)) ['A', 'a', 'a.f', 'B', 'b', 'b.f'] >>> match_query(vbt.CAQuery(base_cls=B)) ['B', 'b', 'b.f'] >>> match_query(vbt.CAQuery(instance=a)) ['a', 'a.f'] >>> match_query(vbt.CAQuery(instance=b)) ['b', 'b.f'] >>> match_query(vbt.CAQuery(instance=a, cacheable='f')) ['a.f'] >>> match_query(vbt.CAQuery(instance=b, cacheable='f')) ['b.f'] >>> match_query(vbt.CAQuery(options=dict(my_option=True))) ['A.f', 'a.f', 'B.f', 'b.f'] >>> match_query(vbt.CAQuery(options=dict(my_option=False))) ['f'] ``` """ if self.cacheable is not None: if not isinstance(setup, (CARunSetup, CAUnboundSetup)): return False if is_cacheable(self.cacheable): if setup.cacheable is not self.cacheable and setup.cacheable.func is not self.cacheable.func: return False elif callable(self.cacheable): if setup.cacheable.func is not self.cacheable: return False elif isinstance(self.cacheable, str): if setup.cacheable.name != self.cacheable: return False elif isinstance(self.cacheable, Regex): if not self.cacheable.matches(setup.cacheable.name): return False else: return False if self.instance_obj is not None: if not isinstance(setup, (CARunSetup, CAInstanceSetup)): return False if setup.instance_obj is not self.instance_obj: return False if self.cls is not None: if 
@define
class CARule(DefineMixin):
    """Data class that represents a rule that should be enforced on setups that match a query."""

    query: CAQuery = define.field()
    """`CAQuery` used in matching."""

    enforce_func: tp.Optional[tp.Callable] = define.field()
    """Function to run on the setup if it has been matched."""

    kind: tp.Optional[tp.MaybeIterable[str]] = define.field(default=None)
    """Kind of a setup to match."""

    exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = define.field(default=None)
    """One or multiple setups to exclude."""

    filter_func: tp.Optional[tp.Callable] = define.field(default=None)
    """Function to filter out a setup."""

    def matches_setup(self, setup: "CABaseSetup") -> bool:
        """Return whether the setup matches the rule.

        The setup must match `CARule.query`, be of one of the kinds in `CARule.kind`
        (if provided), not be listed in `CARule.exclude` (if provided), and pass
        `CARule.filter_func` (if provided)."""
        if not self.query.matches_setup(setup):
            return False
        if self.kind is not None:
            kind = self.kind
            if isinstance(kind, str):
                kind = {kind}
            else:
                kind = set(kind)
            if isinstance(setup, CAClassSetup):
                setup_kind = "class"
            elif isinstance(setup, CAInstanceSetup):
                setup_kind = "instance"
            elif isinstance(setup, CAUnboundSetup):
                setup_kind = "unbound"
            else:
                setup_kind = "runnable"
            if setup_kind not in kind:
                return False
        if self.exclude is not None:
            exclude = self.exclude
            if isinstance(exclude, CABaseSetup):
                exclude = {exclude}
            else:
                exclude = set(exclude)
            if setup in exclude:
                return False
        if self.filter_func is not None:
            if not self.filter_func(setup):
                return False
        return True

    def enforce(self, setup: "CABaseSetup") -> None:
        """Run `CARule.enforce_func` on the setup if it has been matched."""
        if self.matches_setup(setup):
            self.enforce_func(setup)

    @property
    def hash_key(self) -> tuple:
        """Tuple of the rule's fields, with multiple excluded setups packed into a tuple."""
        exclude = self.exclude
        # `exclude` defaults to None, which must not be passed to `tuple()`
        if exclude is not None and not isinstance(exclude, CABaseSetup):
            exclude = tuple(exclude)
        return (
            self.query,
            self.enforce_func,
            self.kind,
            exclude,
            self.filter_func,
        )
`CARunSetup` instances by their hash.""" return self._run_setups @property def setups(self) -> tp.Dict[int, "CABaseSetup"]: """Dict of registered `CABaseSetup` instances by their hash.""" return {**self.class_setups, **self.instance_setups, **self.unbound_setups, **self.run_setups} @property def rules(self) -> tp.List[CARule]: """List of registered `CARule` instances.""" return self._rules def get_setup_by_hash(self, hash_: int) -> tp.Optional["CABaseSetup"]: """Get the setup by its hash.""" if hash_ in self.class_setups: return self.class_setups[hash_] if hash_ in self.instance_setups: return self.instance_setups[hash_] if hash_ in self.unbound_setups: return self.unbound_setups[hash_] if hash_ in self.run_setups: return self.run_setups[hash_] return None def setup_registered(self, setup: "CABaseSetup") -> bool: """Return whether the setup is registered.""" return self.get_setup_by_hash(hash(setup)) is not None def register_setup(self, setup: "CABaseSetup") -> None: """Register a new setup of type `CABaseSetup`.""" if isinstance(setup, CARunSetup): setups = self.run_setups elif isinstance(setup, CAUnboundSetup): setups = self.unbound_setups elif isinstance(setup, CAInstanceSetup): setups = self.instance_setups elif isinstance(setup, CAClassSetup): setups = self.class_setups else: raise TypeError(str(type(setup))) setups[hash(setup)] = setup def deregister_setup(self, setup: "CABaseSetup") -> None: """Deregister a new setup of type `CABaseSetup`. Removes the setup from its respective collection. 
def match_setups(
    self,
    query_like: tp.MaybeIterable[tp.Any] = None,
    collapse: bool = False,
    kind: tp.Optional[tp.MaybeIterable[str]] = None,
    exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
    exclude_children: bool = True,
    filter_func: tp.Optional[tp.Callable] = None,
) -> tp.Set["CABaseSetup"]:
    """Match all setups registered in this registry against `query_like`.

    `query_like` can be one or more query-like objects that will be parsed using `CAQuery.parse`.
    A setup is matched if at least one of the queries matches it.

    Set `collapse` to True to remove child setups that belong to any matched parent setup.

    `kind` can be one or multiple of the following:

    * 'class' to only return class setups (instances of `CAClassSetup`)
    * 'instance' to only return instance setups (instances of `CAInstanceSetup`)
    * 'unbound' to only return unbound setups (instances of `CAUnboundSetup`)
    * 'runnable' to only return runnable setups (instances of `CARunSetup`)

    If `kind` is None, all four kinds are used. An empty iterable yields an empty set.

    Set `exclude` to one or multiple setups to exclude.
    To not exclude their children, set `exclude_children` to False.

    !!! note
        `exclude_children` is applied only when `collapse` is True.

    `filter_func` can be used to filter out setups. For example, `lambda setup: setup.caching_enabled`
    includes only those setups that have caching enabled. It must take a setup and return a boolean
    of whether to include this setup in the final results.

    When `collapse` is False, raises `ValueError` for an unsupported kind string and `TypeError`
    if `kind` is neither a string nor an iterable. Kind strings are lowercased before the
    comparison only when `collapse` is False."""
    if not checks.is_iterable(query_like) or isinstance(query_like, (str, tuple)):
        query_like = [query_like]
    queries = list(map(CAQuery.parse, query_like))
    if kind is None:
        kind = {"class", "instance", "unbound", "runnable"}
    if exclude is None:
        exclude = set()
    elif isinstance(exclude, CABaseSetup):
        exclude = {exclude}
    else:
        exclude = set(exclude)

    def _accepted(setup: "CABaseSetup") -> bool:
        """Whether at least one query matches the setup and the filter (if any) approves it."""
        if not any(q.matches_setup(setup) for q in queries):
            return False
        return filter_func is None or filter_func(setup)

    if not collapse:
        if isinstance(kind, str):
            pool_names = {
                "class": "class_setups",
                "instance": "instance_setups",
                "unbound": "unbound_setups",
                "runnable": "run_setups",
            }
            pool_name = pool_names.get(kind.lower())
            if pool_name is None:
                raise ValueError(f"kind '{kind}' is not supported")
            # Iterate over a snapshot of the collection rather than over the live dict view
            candidates = set(getattr(self, pool_name).values())
            return {setup for setup in candidates if setup not in exclude and _accepted(setup)}
        if checks.is_iterable(kind):
            # `set().union(...)` also works with zero operands, unlike the unbound `set.union(*[])`
            return set().union(
                *[
                    self.match_setups(
                        queries,
                        kind=k,
                        collapse=collapse,
                        exclude=exclude,
                        exclude_children=exclude_children,
                        filter_func=filter_func,
                    )
                    for k in kind
                ]
            )
        raise TypeError(f"kind must be either a string or a sequence of strings, not {type(kind)}")

    kind = {kind} if isinstance(kind, str) else set(kind)
    matches = set()
    # Children of already accepted parents; they are represented by the parent and thus skipped
    collapse_setups = set()

    if "class" in kind:
        class_matches = set()
        for class_setup in self.class_setups.values():
            if not _accepted(class_setup):
                continue
            if class_setup not in exclude:
                class_matches.add(class_setup)
            if class_setup not in exclude or exclude_children:
                collapse_setups |= class_setup.child_setups
        # A class setup may be a child of another matched class setup, which is
        # known only after all class setups have been visited
        matches |= class_matches - collapse_setups

    for kind_name, pool_name in (("instance", "instance_setups"), ("unbound", "unbound_setups")):
        if kind_name not in kind:
            continue
        for setup in getattr(self, pool_name).values():
            if setup in collapse_setups or not _accepted(setup):
                continue
            if setup not in exclude:
                matches.add(setup)
            if setup not in exclude or exclude_children:
                collapse_setups |= setup.child_setups

    if "runnable" in kind:
        for run_setup in self.run_setups.values():
            if run_setup in collapse_setups:
                continue
            if _accepted(run_setup) and run_setup not in exclude:
                matches.add(run_setup)
    return matches
@property
def metrics(self) -> dict:
    """Snapshot of all metrics of this object, keyed by the metric name."""
    metric_names = (
        "hits",
        "misses",
        "total_size",
        "total_elapsed",
        "total_saved",
        "first_run_time",
        "last_run_time",
        "first_hit_time",
        "last_hit_time",
    )
    return {metric_name: getattr(self, metric_name) for metric_name in metric_names}
def deregister(self) -> None:
    """Deregister setup using `CacheableRegistry.deregister_setup`."""
    self.registry.deregister_setup(self)
@property
def last_update_time(self) -> tp.Optional[datetime]:
    """Last time any of `CABaseSetup.use_cache` and `CABaseSetup.whitelist` were updated.

    Returns the later of `CABaseSetup.use_cache_lut` and `CABaseSetup.whitelist_lut`,
    the only available one if the other is None, or None if both are None."""
    update_times = [lut for lut in (self.use_cache_lut, self.whitelist_lut) if lut is not None]
    if not update_times:
        return None
    return max(update_times)
@property
def position_among_similar(self) -> tp.Optional[int]:
    """Zero-based rank of this setup among the setups sharing its readable name.

    Walks `same_type_setups` in iteration order and counts the setups with the same
    `readable_name` that come before this one. Returns None if this setup
    is not among `same_type_setups`."""
    rank = 0
    for other in self.same_type_setups:
        if other is self:
            return rank
        if other.readable_name == self.readable_name:
            rank += 1
    return None
def get_setup_hierarchy(self, readable: bool = True, short_str: bool = False) -> tp.List[dict]:
    """Get the setup hierarchy by recursively traversing the child setups.

    Each child setup is rendered as `readable_str` if `readable` is True, otherwise as
    `short_str` if `short_str` is True, otherwise the setup object itself is used.
    A child that is an instance of `CASetupDelegatorMixin` becomes a dict with the keys
    "parent" (the rendered child) and "children" (its own hierarchy); any other child
    is appended as rendered."""
    results = []
    for setup in self.child_setups:
        if readable:
            setup_obj = setup.readable_str
        elif short_str:
            setup_obj = setup.short_str
        else:
            setup_obj = setup
        if isinstance(setup, CASetupDelegatorMixin):
            # Forward both flags so that nested levels are rendered the same way as this one
            children = setup.get_setup_hierarchy(readable=readable, short_str=short_str)
            results.append(dict(parent=setup_obj, children=children))
        else:
            results.append(setup_obj)
    return results
@property
def hits(self) -> int:
    """Number of hits summed over all child setups."""
    return sum(child.hits for child in self.child_setups)
def get_stats(
    self,
    readable: bool = True,
    short_str: bool = False,
    index_by_hash: bool = False,
    filter_func: tp.Optional[tp.Callable] = None,
    include: tp.Optional[tp.MaybeSequence[str]] = None,
    exclude: tp.Optional[tp.MaybeSequence[str]] = None,
) -> tp.Optional[tp.Frame]:
    """Get a DataFrame out of stats dicts of child setups.

    One row is built from `get_stats(readable=readable, short_str=short_str)` of each child
    setup that is accepted by `filter_func` (all child setups if `filter_func` is None).
    Rows are indexed by the column "hash" if `index_by_hash` is True, otherwise by the column
    "string" (with the index named "object"), and are sorted by index.

    `include` and `exclude` can each be a column name or a sequence of column names:
    `include` selects the columns to keep and `exclude` removes columns from that selection.

    Returns None if there are no child setups, if `filter_func` rejects all of them,
    or if no columns are left after applying `include` and `exclude`."""
    # Collect the rows first: an empty frame has no "hash"/"string" column to index by
    stats = [
        setup.get_stats(readable=readable, short_str=short_str)
        for setup in self.child_setups
        if filter_func is None or filter_func(setup)
    ]
    if not stats:
        return None
    df = pd.DataFrame(stats)
    if index_by_hash:
        df = df.set_index("hash")
        df.index.name = "hash"
    else:
        df = df.set_index("string")
        df.index.name = "object"
    if include is not None:
        if isinstance(include, str):
            include = [include]
        columns = include
    else:
        columns = df.columns
    if exclude is not None:
        if isinstance(exclude, str):
            exclude = [exclude]
        columns = [c for c in columns if c not in exclude]
    if len(columns) == 0:
        return None
    return df[columns].sort_index()
@staticmethod
def get_cacheable_superclasses(cls: tp.Type[Cacheable]) -> tp.List[tp.Type[Cacheable]]:
    """List the superclasses of `cls` that subclass `Cacheable`, in MRO order.

    The class `cls` itself is not part of the result."""
    return [
        ancestor
        for ancestor in inspect.getmro(cls)
        if issubclass(ancestor, Cacheable) and ancestor is not cls
    ]
@staticmethod
def get_subclass_setups(registry: CacheableRegistry, cls: tp.Type[Cacheable]) -> tp.List["CAClassSetup"]:
    """Setups of type `CAClassSetup` of each in `CAClassSetup.get_cacheable_subclasses`.

    Only subclasses that have a class setup in `registry` are considered; for each of them,
    the result of calling `get_ca_setup` on the subclass is collected."""
    return [
        sub_cls.get_ca_setup()
        for sub_cls in CAClassSetup.get_cacheable_subclasses(cls)
        if registry.get_class_setup(sub_cls) is not None
    ]
@property
def instance_setups(self) -> tp.Set["CAInstanceSetup"]:
    """Setups of type `CAInstanceSetup` of instances of the class.

    Selects those instance setups of the registry whose `class_setup` is this very setup."""
    return {
        candidate
        for candidate in self.registry.instance_setups.values()
        if candidate.class_setup is self
    }
@property
def any_whitelist_lut(self) -> tp.Optional[datetime]:
    """Last time `CABaseSetup.whitelist` was updated in this class or any of its superclasses.

    Returns None if neither this setup nor any superclass setup has such a time."""
    candidates = [self.whitelist_lut]
    candidates.extend(setup.whitelist_lut for setup in self.superclass_setups)
    known = [lut for lut in candidates if lut is not None]
    if not known:
        return None
    return max(known)
@property
def instance_obj(self) -> tp.Union[Cacheable, object]:
    """Instance object, or `_GARBAGE` if the instance has been destroyed.

    `instance` is a weak reference that returns None once its referent is gone.
    It is dereferenced exactly once: a second dereference could return None
    if the referent gets collected in between the two calls."""
    obj = self.instance()
    if obj is None:
        return _GARBAGE
    return obj
@property
def readable_name(self) -> str:
    """Lowercased class name of the instance, or "_GARBAGE" if the instance was destroyed."""
    if not self.contains_garbage:
        instance_type = type(self.instance_obj)
        return instance_type.__name__.lower()
    return "_GARBAGE"
hint Use class attributes instead of instance attributes to access unbound callables.""" cacheable: cacheableT = define.field(default=None, validator=_assert_value_not_none) """Cacheable object.""" @staticmethod def get_hash(cacheable: cacheableT) -> int: return hash((cacheable,)) @classmethod def get( cls: tp.Type[CAUnboundSetupT], cacheable: cacheableT, registry: CacheableRegistry = ca_reg, **kwargs, ) -> tp.Optional[CAUnboundSetupT]: """Get setup from `CacheableRegistry` or register a new one. `**kwargs` are passed to `CAUnboundSetup.__init__`.""" from vectorbtpro._settings import settings caching_cfg = settings["caching"] if caching_cfg["disable_machinery"]: return None setup = registry.get_unbound_setup(cacheable) if setup is not None: if not setup.active: return None return setup instance = cls(cacheable=cacheable, registry=registry, **kwargs) instance.enforce_rules() if instance.active: instance.register() return instance def __attrs_post_init__(self) -> None: CABaseSetup.__attrs_post_init__(self) if not is_bindable_cacheable(self.cacheable): raise TypeError("cacheable must be either cacheable_property or cacheable_method") @property def query(self) -> CAQuery: return CAQuery(cacheable=self.cacheable) @property def run_setups(self) -> tp.Set["CARunSetup"]: """Setups of type `CARunSetup` of bound cacheables.""" matches = set() for run_setup in self.registry.run_setups.values(): if run_setup.unbound_setup is self: matches.add(run_setup) return matches @property def child_setups(self) -> tp.Set["CARunSetup"]: return self.run_setups @property def same_type_setups(self) -> ValuesView: return self.registry.unbound_setups.values() @property def short_str(self) -> str: if is_cacheable_property(self.cacheable): return f"" return f"" @property def readable_name(self) -> str: if is_cacheable_property(self.cacheable): return f"{self.cacheable.func.__name__}" return f"{self.cacheable.func.__name__}()" @property def hash_key(self) -> tuple: return (self.cacheable,) 
CARunSetupT = tp.TypeVar("CARunSetupT", bound="CARunSetup")


@define
class CARunResult(DefineMixin):
    """Class that represents a cached result of a run.

    !!! note
        Hashed solely by the hash of the arguments `args_hash`."""

    args_hash: int = define.field()
    """Hash of the arguments."""

    result: tp.Any = define.field()
    """Result of the run."""

    timer: Timer = define.field()
    """Timer used to measure the execution time."""

    def __attrs_post_init__(self) -> None:
        """Initialize the run time and the hit statistics."""
        object.__setattr__(self, "_run_time", datetime.now(timezone.utc))
        object.__setattr__(self, "_hits", 0)
        object.__setattr__(self, "_first_hit_time", None)
        object.__setattr__(self, "_last_hit_time", None)

    @staticmethod
    def get_hash(args_hash: int) -> int:
        return hash((args_hash,))

    @property
    def result_size(self) -> int:
        """Get size of the result in memory."""
        return sys.getsizeof(self.result)

    @property
    def run_time(self) -> datetime:
        """Time of the run."""
        return object.__getattribute__(self, "_run_time")

    @property
    def hits(self) -> int:
        """Number of hits."""
        return object.__getattribute__(self, "_hits")

    @property
    def first_hit_time(self) -> tp.Optional[datetime]:
        """Time of the first hit."""
        return object.__getattribute__(self, "_first_hit_time")

    @property
    def last_hit_time(self) -> tp.Optional[datetime]:
        """Time of the last hit."""
        return object.__getattribute__(self, "_last_hit_time")

    def hit(self) -> tp.Any:
        """Hit the result.

        Updates the hit statistics and returns the cached result."""
        hit_time = datetime.now(timezone.utc)
        if self.first_hit_time is None:
            object.__setattr__(self, "_first_hit_time", hit_time)
        object.__setattr__(self, "_last_hit_time", hit_time)
        object.__setattr__(self, "_hits", self.hits + 1)
        return self.result

    @property
    def hash_key(self) -> tuple:
        return (self.args_hash,)


@define
class CARunSetup(CABaseSetup, DefineMixin):
    """Class that represents a runnable cacheable setup.

    Takes care of running functions and caching the results using `CARunSetup.run`.

    Accepts as `cacheable` either `vectorbtpro.utils.decorators.cacheable_property`,
    `vectorbtpro.utils.decorators.cacheable_method`, or `vectorbtpro.utils.decorators.cacheable`.

    Hashed by the callable and optionally the id of the instance it's bound to.
    This way, it can be uniquely identified among all setups.

    !!! note
        Cacheable properties and methods must provide an instance.

        Only one instance per each unique combination of `cacheable` and `instance` can exist at a time.

    If `use_cache` or `whitelist` are None, inherits a non-empty value either from its parent instance setup
    or its parent unbound setup. If both setups have non-empty values, takes the one that has been
    updated more recently.

    !!! note
        Use `CARunSetup.get` class method instead of `CARunSetup.__init__` to create a setup. The class method
        first checks whether a setup with the same hash has already been registered, and if so, returns it.
        Otherwise, creates and registers a new one. Using `CARunSetup.__init__` will throw an error if there
        is a setup with the same hash."""

    cacheable: cacheableT = define.field(default=None, validator=_assert_value_not_none)
    """Cacheable object."""

    instance: tp.Union[Cacheable, ReferenceType] = define.field(default=None)
    """Cacheable instance."""

    max_size: tp.Optional[int] = define.field(default=None)
    """Maximum number of entries in `CARunSetup.cache`."""

    ignore_args: tp.Optional[tp.Iterable[tp.AnnArgQuery]] = define.field(default=None)
    """Arguments to ignore when hashing."""

    cache: tp.Dict[int, CARunResult] = define.field(factory=dict)
    """Dict of cached `CARunResult` instances by their hash."""

    @staticmethod
    def get_hash(cacheable: cacheableT, instance: tp.Optional[Cacheable] = None) -> int:
        return hash((cacheable, get_obj_id(instance) if instance is not None else None))

    @classmethod
    def get(
        cls: tp.Type[CARunSetupT],
        cacheable: cacheableT,
        instance: tp.Optional[Cacheable] = None,
        registry: CacheableRegistry = ca_reg,
        **kwargs,
    ) -> tp.Optional[CARunSetupT]:
        """Get setup from `CacheableRegistry` or register a new one.

        Returns None if the caching machinery is disabled or the registered setup is inactive.

        `**kwargs` are passed to `CARunSetup.__init__`."""
        from vectorbtpro._settings import settings

        caching_cfg = settings["caching"]
        if caching_cfg["disable_machinery"]:
            return None

        setup = registry.get_run_setup(cacheable, instance=instance)
        if setup is not None:
            if not setup.active:
                return None
            return setup
        instance = cls(cacheable=cacheable, instance=instance, registry=registry, **kwargs)
        instance.enforce_rules()
        if instance.active:
            instance.register()
        return instance

    def __attrs_post_init__(self) -> None:
        """Validate the cacheable and the instance, store the instance as a weak reference,
        and inherit unset `use_cache` and `whitelist` from the parent setups."""
        CABaseSetup.__attrs_post_init__(self)

        if not is_cacheable(self.cacheable):
            raise TypeError("cacheable must be either cacheable_property, cacheable_method, or cacheable")
        if self.instance is None:
            if is_cacheable_property(self.cacheable):
                raise ValueError("CARunSetup requires an instance for cacheable_property")
            elif is_cacheable_method(self.cacheable):
                raise ValueError("CARunSetup requires an instance for cacheable_method")
        else:
            checks.assert_instance_of(self.instance, Cacheable)
            if is_cacheable_function(self.cacheable):
                raise ValueError("Cacheable functions can't have an instance")

        if self.instance is not None and not isinstance(self.instance, ReferenceType):
            checks.assert_instance_of(self.instance, Cacheable)
            # Deregister this setup as soon as the instance gets destroyed
            instance_ref = ref(self.instance, lambda ref: self.registry.deregister_setup(self))
            object.__setattr__(self, "instance", instance_ref)

        if self.use_cache is None or self.whitelist is None:
            instance_setup = self.instance_setup
            unbound_setup = self.unbound_setup
            if self.use_cache is None:
                if (
                    instance_setup is not None
                    and unbound_setup is not None
                    and instance_setup.use_cache is not None
                    and unbound_setup.use_cache is not None
                ):
                    if unbound_setup.use_cache_lut is not None and (
                        instance_setup.class_setup.any_use_cache_lut is None
                        or unbound_setup.use_cache_lut > instance_setup.class_setup.any_use_cache_lut
                    ):
                        # Unbound setup was updated more recently than any superclass setup
                        object.__setattr__(self, "use_cache", unbound_setup.use_cache)
                    else:
                        object.__setattr__(self, "use_cache", instance_setup.use_cache)
                elif instance_setup is not None and instance_setup.use_cache is not None:
                    object.__setattr__(self, "use_cache", instance_setup.use_cache)
                elif unbound_setup is not None and unbound_setup.use_cache is not None:
                    object.__setattr__(self, "use_cache", unbound_setup.use_cache)
            if self.whitelist is None:
                if (
                    instance_setup is not None
                    and unbound_setup is not None
                    and instance_setup.whitelist is not None
                    and unbound_setup.whitelist is not None
                ):
                    if unbound_setup.whitelist_lut is not None and (
                        instance_setup.class_setup.any_whitelist_lut is None
                        or unbound_setup.whitelist_lut > instance_setup.class_setup.any_whitelist_lut
                    ):
                        # Unbound setup was updated more recently than any superclass setup
                        object.__setattr__(self, "whitelist", unbound_setup.whitelist)
                    else:
                        object.__setattr__(self, "whitelist", instance_setup.whitelist)
                elif instance_setup is not None and instance_setup.whitelist is not None:
                    object.__setattr__(self, "whitelist", instance_setup.whitelist)
                elif unbound_setup is not None and unbound_setup.whitelist is not None:
                    object.__setattr__(self, "whitelist", unbound_setup.whitelist)

    @property
    def query(self) -> CAQuery:
        return CAQuery(cacheable=self.cacheable, instance=self.instance_obj)

    @property
    def instance_obj(self) -> tp.Union[Cacheable, object]:
        """Instance object.

        Returns None if there is no instance, and `_GARBAGE` if the instance was destroyed."""
        instance_ref = self.instance
        if instance_ref is None:
            return None
        # Dereference only once: the referent may be destroyed between two separate calls,
        # in which case the second call would return None instead of `_GARBAGE`
        instance = instance_ref()
        if instance is None:
            return _GARBAGE
        return instance

    @property
    def contains_garbage(self) -> bool:
        """Whether instance was destroyed."""
        return self.instance_obj is _GARBAGE

    @property
    def instance_setup(self) -> tp.Optional[CAInstanceSetup]:
        """Setup of type `CAInstanceSetup` of the instance this cacheable is bound to."""
        if self.instance_obj is None or self.contains_garbage:
            return None
        return CAInstanceSetup.get(self.instance_obj, self.registry)

    @property
    def unbound_setup(self) -> tp.Optional[CAUnboundSetup]:
        """Setup of type `CAUnboundSetup` of the unbound cacheable."""
        return self.registry.get_unbound_setup(self.cacheable)

    @property
    def hits(self) -> int:
        """Number of hits summed across all cached results."""
        return sum(run_result.hits for run_result in self.cache.values())

    @property
    def misses(self) -> int:
        """Number of misses, counted as the number of cached results."""
        return len(self.cache)

    @property
    def total_size(self) -> int:
        """Size of all cached results as measured by `CARunResult.result_size`."""
        return sum(run_result.result_size for run_result in self.cache.values())

    @property
    def total_elapsed(self) -> tp.Optional[timedelta]:
        """Elapsed time summed across all cached results, or None if the cache is empty."""
        total_elapsed = None
        for run_result in self.cache.values():
            elapsed = run_result.timer.elapsed(readable=False)
            if total_elapsed is None:
                total_elapsed = elapsed
            else:
                total_elapsed += elapsed
        return total_elapsed

    @property
    def total_saved(self) -> tp.Optional[timedelta]:
        """Elapsed time multiplied by the number of hits and summed across all cached results,
        or None if the cache is empty."""
        total_saved = None
        for run_result in self.cache.values():
            saved = run_result.timer.elapsed(readable=False) * run_result.hits
            if total_saved is None:
                total_saved = saved
            else:
                total_saved += saved
        return total_saved

    @property
    def first_run_time(self) -> tp.Optional[datetime]:
        """Run time of the first entry in the cache, or None if the cache is empty."""
        if len(self.cache) == 0:
            return None
        return next(iter(self.cache.values())).run_time

    @property
    def last_run_time(self) -> tp.Optional[datetime]:
        """Run time of the last entry in the cache, or None if the cache is empty."""
        if len(self.cache) == 0:
            return None
        return next(reversed(self.cache.values())).run_time

    @property
    def first_hit_time(self) -> tp.Optional[datetime]:
        """Earliest first hit time among all cached results, or None if there were no hits."""
        first_hit_times = (run_result.first_hit_time for run_result in self.cache.values())
        return min((hit_time for hit_time in first_hit_times if hit_time is not None), default=None)

    @property
    def last_hit_time(self) -> tp.Optional[datetime]:
        """Latest last hit time among all cached results, or None if there were no hits."""
        last_hit_times = (run_result.last_hit_time for run_result in self.cache.values())
        return max((hit_time for hit_time in last_hit_times if hit_time is not None), default=None)

    def run_func(self, *args, **kwargs) -> tp.Any:
        """Run the setup's function without caching."""
        instance_obj = self.instance_obj
        if instance_obj is not None:
            return self.cacheable.func(instance_obj, *args, **kwargs)
        return self.cacheable.func(*args, **kwargs)

    def get_args_hash(self, *args, **kwargs) -> tp.Optional[int]:
        """Get the hash of the passed arguments.

        `CARunSetup.ignore_args` gets extended with `ignore_args` under `vectorbtpro._settings.caching`.

        If no arguments were passed, hashes None."""
        if len(args) == 0 and len(kwargs) == 0:
            return hash(None)

        from vectorbtpro._settings import settings

        caching_cfg = settings["caching"]

        ignore_args = list(caching_cfg["ignore_args"])
        if self.ignore_args is not None:
            ignore_args.extend(list(self.ignore_args))

        instance_obj = self.instance_obj
        return hash_args(
            self.cacheable.func,
            args if instance_obj is None else (get_obj_id(instance_obj), *args),
            kwargs,
            ignore_args=ignore_args,
        )

    def run_func_and_cache(self, *args, **kwargs) -> tp.Any:
        """Run the setup's function and cache the result.

        Hashes the arguments using `CARunSetup.get_args_hash`, runs the function using
        `CARunSetup.run_func`, wraps the result using `CARunResult`, and uses the hash
        as a key to store the instance of `CARunResult` into `CARunSetup.cache` for later retrieval.

        If `CARunSetup.max_size` is reached, removes the oldest entries to make room for the new one.
        If `CARunSetup.max_size` is zero or negative, runs the function without storing the result."""
        args_hash = self.get_args_hash(*args, **kwargs)
        run_result_hash = CARunResult.get_hash(args_hash)
        if run_result_hash in self.cache:
            return self.cache[run_result_hash].hit()
        if self.max_size is not None:
            if self.max_size <= 0:
                # There is no room for any entry, and nothing to evict from an empty cache
                return self.run_func(*args, **kwargs)
            # Dicts preserve insertion order, so the first key belongs to the oldest entry
            while len(self.cache) >= self.max_size:
                del self.cache[next(iter(self.cache))]
        with Timer() as timer:
            result = self.run_func(*args, **kwargs)
        run_result = CARunResult(args_hash, result, timer=timer)
        self.cache[run_result_hash] = run_result
        return result

    def run(self, *args, **kwargs) -> tp.Any:
        """Run the setup and cache it depending on a range of conditions.

        Runs `CARunSetup.run_func` if caching is disabled or arguments are not hashable,
        and `CARunSetup.run_func_and_cache` otherwise."""
        if self.caching_enabled:
            try:
                return self.run_func_and_cache(*args, **kwargs)
            except UnhashableArgsError:
                # Fall back to a regular run
                pass
        return self.run_func(*args, **kwargs)

    def clear_cache(self) -> None:
        """Clear the cache."""
        self.cache.clear()

    @property
    def same_type_setups(self) -> ValuesView:
        return self.registry.run_setups.values()

    @property
    def short_str(self) -> str:
        if self.contains_garbage:
            return ""
        if is_cacheable_property(self.cacheable):
            return (
                f""
            )
        if is_cacheable_method(self.cacheable):
            return (
                f""
            )
        return f""

    @property
    def readable_name(self) -> str:
        if self.contains_garbage:
            return "_GARBAGE"
        if is_cacheable_property(self.cacheable):
            return f"{type(self.instance_obj).__name__.lower()}.{self.cacheable.func.__name__}"
        if is_cacheable_method(self.cacheable):
            return f"{type(self.instance_obj).__name__.lower()}.{self.cacheable.func.__name__}()"
        return f"{self.cacheable.__name__}()"

    @property
    def readable_str(self) -> str:
        if self.contains_garbage:
            return f"_GARBAGE:{self.position_among_similar}"
        if is_cacheable_property(self.cacheable):
            return (
                f"{type(self.instance_obj).__name__.lower()}:"
                f"{self.instance_setup.position_among_similar}."
                f"{self.cacheable.func.__name__}"
            )
        if is_cacheable_method(self.cacheable):
            return (
                f"{type(self.instance_obj).__name__.lower()}:"
                f"{self.instance_setup.position_among_similar}."
                f"{self.cacheable.func.__name__}()"
            )
        return f"{self.cacheable.__name__}():{self.position_among_similar}"

    @property
    def hash_key(self) -> tuple:
        instance_obj = self.instance_obj
        return self.cacheable, get_obj_id(instance_obj) if instance_obj is not None else None


class CAQueryDelegator(CASetupDelegatorMixin):
    """Class that delegates any setups that match a query.

    `*args`, `collapse`, and `**kwargs` are passed to `CacheableRegistry.match_setups`."""

    def __init__(self, *args, registry: CacheableRegistry = ca_reg, collapse: bool = True, **kwargs) -> None:
        self._args = args
        kwargs["collapse"] = collapse
        self._kwargs = kwargs
        self._registry = registry

    @property
    def args(self) -> tp.Args:
        """Arguments."""
        return self._args

    @property
    def kwargs(self) -> tp.Kwargs:
        """Keyword arguments."""
        return self._kwargs

    @property
    def registry(self) -> CacheableRegistry:
        """Registry of type `CacheableRegistry`."""
        return self._registry

    @property
    def child_setups(self) -> tp.Set[CABaseSetup]:
        """Get child setups by matching them using `CacheableRegistry.match_setups`."""
        return self.registry.match_setups(*self.args, **self.kwargs)


def get_cache_stats(*args, **kwargs) -> tp.Optional[tp.Frame]:
    """Get cache stats globally or of an object.

    Keyword arguments that are accepted by `CAQueryDelegator.get_stats` are passed to that method,
    while the remaining arguments are passed to `CAQueryDelegator`."""
    delegator_kwargs = {}
    stats_kwargs = {}
    if len(kwargs) > 0:
        stats_arg_names = get_func_arg_names(CAQueryDelegator.get_stats)
        for k, v in kwargs.items():
            if k in stats_arg_names:
                stats_kwargs[k] = v
            else:
                delegator_kwargs[k] = v
    return CAQueryDelegator(*args, **delegator_kwargs).get_stats(**stats_kwargs)


def print_cache_stats(*args, **kwargs) -> None:
    """Print cache stats globally or of an object."""
    ptable(get_cache_stats(*args, **kwargs))


def clear_cache(*args, **kwargs) -> None:
    """Clear cache globally or of an object."""
    return CAQueryDelegator(*args, **kwargs).clear_cache()


def collect_garbage() -> None:
    """Collect garbage."""
    import gc

    gc.collect()


def flush() -> None:
    """Clear cache and collect garbage."""
    clear_cache()
    collect_garbage()


def disable_caching(clear_cache: bool = True) -> None:
    """Disable caching globally.

    If `clear_cache` is True, also clears the cache of all setups matched by `CAQueryDelegator`."""
    from vectorbtpro._settings import settings

    caching_cfg = settings["caching"]

    caching_cfg["disable"] = True
    caching_cfg["disable_whitelist"] = True
    caching_cfg["disable_machinery"] = True

    if clear_cache:
        # The argument shadows the module-level function `clear_cache`, thus use the delegator directly
        CAQueryDelegator().clear_cache()


def enable_caching() -> None:
    """Enable caching globally."""
    from vectorbtpro._settings import settings

    caching_cfg = settings["caching"]

    caching_cfg["disable"] = False
    caching_cfg["disable_whitelist"] = False
    caching_cfg["disable_machinery"] = False


class CachingDisabled(Base):
    """Context manager to disable caching.

    If `query_like` is None, changes the global caching settings and restores them upon exit.
    Otherwise, registers a rule that disables the matched setups and restores their state upon exit."""

    def __init__(
        self,
        query_like: tp.Optional[tp.Any] = None,
        use_base_cls: bool = True,
        kind: tp.Optional[tp.MaybeIterable[str]] = None,
        exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
        filter_func: tp.Optional[tp.Callable] = None,
        registry: CacheableRegistry = ca_reg,
        disable_whitelist: bool = True,
        disable_machinery: bool = True,
        clear_cache: bool = True,
        silence_warnings: bool = False,
    ) -> None:
        self._query_like = query_like
        self._use_base_cls = use_base_cls
        self._kind = kind
        self._exclude = exclude
        self._filter_func = filter_func
        self._registry = registry
        self._disable_whitelist = disable_whitelist
        self._disable_machinery = disable_machinery
        self._clear_cache = clear_cache
        self._silence_warnings = silence_warnings

        self._rule = None
        self._init_settings = None
        self._init_setup_settings = None

    @property
    def query_like(self) -> tp.Optional[tp.Any]:
        """See `CAQuery.parse`."""
        return self._query_like

    @property
    def use_base_cls(self) -> bool:
        """See `CAQuery.parse`."""
        return self._use_base_cls

    @property
    def kind(self) -> tp.Optional[tp.MaybeIterable[str]]:
        """See `CARule.kind`."""
        return self._kind

    @property
    def exclude(self) -> tp.Optional[tp.MaybeIterable["CABaseSetup"]]:
        """See `CARule.exclude`."""
        return self._exclude

    @property
    def filter_func(self) -> tp.Optional[tp.Callable]:
        """See `CARule.filter_func`."""
        return self._filter_func

    @property
    def registry(self) -> CacheableRegistry:
        """Registry of type `CacheableRegistry`."""
        return self._registry

    @property
    def disable_whitelist(self) -> bool:
        """Whether to disable whitelist."""
        return self._disable_whitelist

    @property
    def disable_machinery(self) -> bool:
        """Whether to disable machinery."""
        return self._disable_machinery

    @property
    def clear_cache(self) -> bool:
        """Whether to clear global cache when entering or local cache when disabling caching."""
        return self._clear_cache

    @property
    def silence_warnings(self) -> bool:
        """Whether to silence warnings."""
        return self._silence_warnings

    @property
    def rule(self) -> tp.Optional[CARule]:
        """Rule."""
        return self._rule

    @property
    def init_settings(self) -> tp.Kwargs:
        """Initial caching settings."""
        return self._init_settings

    @property
    def init_setup_settings(self) -> tp.Dict[int, dict]:
        """Initial setup settings."""
        return self._init_setup_settings

    def __enter__(self) -> tp.Self:
        if self.query_like is None:
            from vectorbtpro._settings import settings

            caching_cfg = settings["caching"]

            # Remember the global settings to restore them upon exit
            self._init_settings = dict(
                disable=caching_cfg["disable"],
                disable_whitelist=caching_cfg["disable_whitelist"],
                disable_machinery=caching_cfg["disable_machinery"],
            )

            caching_cfg["disable"] = True
            caching_cfg["disable_whitelist"] = self.disable_whitelist
            caching_cfg["disable_machinery"] = self.disable_machinery

            if self.clear_cache:
                # Refers to the module-level function, not to the property
                clear_cache()
        else:

            def _enforce_func(setup):
                if self.disable_machinery:
                    setup.deactivate()
                if self.disable_whitelist:
                    setup.disable_whitelist()
                setup.disable_caching(clear_cache=self.clear_cache)

            query = CAQuery.parse(self.query_like, use_base_cls=self.use_base_cls)
            rule = CARule(
                query,
                _enforce_func,
                kind=self.kind,
                exclude=self.exclude,
                filter_func=self.filter_func,
            )
            self._rule = rule
            self.registry.register_rule(rule)

            # Remember the state of each registered setup before enforcing the rule on it
            init_setup_settings = dict()
            for setup_hash, setup in self.registry.setups.items():
                init_setup_settings[setup_hash] = dict(
                    active=setup.active,
                    whitelist=setup.whitelist,
                    use_cache=setup.use_cache,
                )
                rule.enforce(setup)
            self._init_setup_settings = init_setup_settings

        return self

    def __exit__(self, *args) -> None:
        if self.query_like is None:
            from vectorbtpro._settings import settings

            caching_cfg = settings["caching"]

            caching_cfg["disable"] = self.init_settings["disable"]
            caching_cfg["disable_whitelist"] = self.init_settings["disable_whitelist"]
            caching_cfg["disable_machinery"] = self.init_settings["disable_machinery"]
        else:
            self.registry.deregister_rule(self.rule)

            # Restore only those setups that are still registered
            for setup_hash, setup_settings in self.init_setup_settings.items():
                if setup_hash in self.registry.setups:
                    setup = self.registry.setups[setup_hash]
                    if self.disable_machinery and setup_settings["active"]:
                        setup.activate()
                    if self.disable_whitelist and setup_settings["whitelist"]:
                        setup.enable_whitelist()
                    if setup_settings["use_cache"]:
                        setup.enable_caching(silence_warnings=self.silence_warnings)


def with_caching_disabled(*args, **caching_disabled_kwargs) -> tp.Callable:
    """Decorator to run a function with `CachingDisabled`.

    Can be applied either directly on a function or called with keyword arguments,
    which are then passed to `CachingDisabled`."""

    def decorator(func: tp.Callable) -> tp.Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> tp.Any:
            with CachingDisabled(**caching_disabled_kwargs):
                return func(*args, **kwargs)

        return wrapper

    if len(args) == 0:
        return decorator
    elif len(args) == 1:
        return decorator(args[0])
    raise ValueError("Either function or keyword arguments must be passed")


class CachingEnabled(Base):
    """Context manager to enable caching.

    If `query_like` is None, changes the global caching settings and restores them upon exit.
    Otherwise, registers a rule that enables the matched setups and restores their state upon exit."""

    def __init__(
        self,
        query_like: tp.Optional[tp.Any] = None,
        use_base_cls: bool = True,
        kind: tp.Optional[tp.MaybeIterable[str]] = None,
        exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
        filter_func: tp.Optional[tp.Callable] = None,
        registry: CacheableRegistry = ca_reg,
        enable_whitelist: bool = True,
        enable_machinery: bool = True,
        clear_cache: bool = True,
        silence_warnings: bool = False,
    ) -> None:
        self._query_like = query_like
        self._use_base_cls = use_base_cls
        self._kind = kind
        self._exclude = exclude
        self._filter_func = filter_func
        self._registry = registry
        self._enable_whitelist = enable_whitelist
        self._enable_machinery = enable_machinery
        self._clear_cache = clear_cache
        self._silence_warnings = silence_warnings

        self._rule = None
        self._init_settings = None
        self._init_setup_settings = None

    @property
    def query_like(self) -> tp.Optional[tp.Any]:
        """See `CAQuery.parse`."""
        return self._query_like

    @property
    def use_base_cls(self) -> bool:
        """See `CAQuery.parse`."""
        return self._use_base_cls

    @property
    def kind(self) -> tp.Optional[tp.MaybeIterable[str]]:
        """See `CARule.kind`."""
        return self._kind

    @property
    def exclude(self) -> tp.Optional[tp.MaybeIterable["CABaseSetup"]]:
        """See `CARule.exclude`."""
        return self._exclude

    @property
    def filter_func(self) -> tp.Optional[tp.Callable]:
        """See `CARule.filter_func`."""
        return self._filter_func

    @property
    def registry(self) -> CacheableRegistry:
        """Registry of type `CacheableRegistry`."""
        return self._registry

    @property
    def enable_whitelist(self) -> bool:
        """Whether to enable whitelist."""
        return self._enable_whitelist

    @property
    def enable_machinery(self) -> bool:
        """Whether to enable machinery."""
        return self._enable_machinery

    @property
    def clear_cache(self) -> bool:
        """Whether to clear global cache when exiting or local cache when disabling caching."""
        return self._clear_cache

    @property
    def silence_warnings(self) -> bool:
        """Whether to silence warnings."""
        return self._silence_warnings

    @property
    def rule(self) -> tp.Optional[CARule]:
        """Rule."""
        return self._rule

    @property
    def init_settings(self) -> tp.Kwargs:
        """Initial caching settings."""
        return self._init_settings

    @property
    def init_setup_settings(self) -> tp.Dict[int, dict]:
        """Initial setup settings."""
        return self._init_setup_settings

    def __enter__(self) -> tp.Self:
        if self.query_like is None:
            from vectorbtpro._settings import settings

            caching_cfg = settings["caching"]

            # Remember the global settings to restore them upon exit
            self._init_settings = dict(
                disable=caching_cfg["disable"],
                disable_whitelist=caching_cfg["disable_whitelist"],
                disable_machinery=caching_cfg["disable_machinery"],
            )

            caching_cfg["disable"] = False
            caching_cfg["disable_whitelist"] = not self.enable_whitelist
            caching_cfg["disable_machinery"] = not self.enable_machinery
        else:

            def _enforce_func(setup):
                if self.enable_machinery:
                    setup.activate()
                if self.enable_whitelist:
                    setup.enable_whitelist()
                setup.enable_caching(silence_warnings=self.silence_warnings)

            query = CAQuery.parse(self.query_like, use_base_cls=self.use_base_cls)
            rule = CARule(
                query,
                _enforce_func,
                kind=self.kind,
                exclude=self.exclude,
                filter_func=self.filter_func,
            )
            self._rule = rule
            self.registry.register_rule(rule)

            # Remember the state of each registered setup before enforcing the rule on it
            init_setup_settings = dict()
            for setup_hash, setup in self.registry.setups.items():
                init_setup_settings[setup_hash] = dict(
                    active=setup.active,
                    whitelist=setup.whitelist,
                    use_cache=setup.use_cache,
                )
                rule.enforce(setup)
            self._init_setup_settings = init_setup_settings

        return self

    def __exit__(self, *args) -> None:
        if self.query_like is None:
            from vectorbtpro._settings import settings

            caching_cfg = settings["caching"]

            caching_cfg["disable"] = self.init_settings["disable"]
            caching_cfg["disable_whitelist"] = self.init_settings["disable_whitelist"]
            caching_cfg["disable_machinery"] = self.init_settings["disable_machinery"]

            if self.clear_cache:
                # Refers to the module-level function, not to the property
                clear_cache()
        else:
            self.registry.deregister_rule(self.rule)

            # Restore only those setups that are still registered
            for setup_hash, setup_settings in self.init_setup_settings.items():
                if setup_hash in self.registry.setups:
                    setup = self.registry.setups[setup_hash]
                    if self.enable_machinery and not setup_settings["active"]:
                        setup.deactivate()
                    if self.enable_whitelist and not setup_settings["whitelist"]:
                        setup.disable_whitelist()
                    if not setup_settings["use_cache"]:
                        setup.disable_caching(clear_cache=self.clear_cache)


def with_caching_enabled(*args, **caching_enabled_kwargs) -> tp.Callable:
    """Decorator to run a function with `CachingEnabled`.

    Can be applied either directly on a function or called with keyword arguments,
    which are then passed to `CachingEnabled`."""

    def decorator(func: tp.Callable) -> tp.Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> tp.Any:
            with CachingEnabled(**caching_enabled_kwargs):
                return func(*args, **kwargs)

        return wrapper

    if len(args) == 0:
        return decorator
    elif len(args) == 1:
        return decorator(args[0])
    raise ValueError("Either function or keyword arguments must be passed")


# ==================================== VBTPROXYZ
==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Global registry for chunkables.""" from vectorbtpro import _typing as tp from vectorbtpro.utils import checks from vectorbtpro.utils.attr_ import DefineMixin, define from vectorbtpro.utils.base import Base from vectorbtpro.utils.chunking import chunked, resolve_chunked, resolve_chunked_option from vectorbtpro.utils.config import merge_dicts from vectorbtpro.utils.template import RepEval __all__ = [ "ChunkableRegistry", "ch_reg", "register_chunkable", ] @define class ChunkedSetup(DefineMixin): """Class that represents a chunkable setup. !!! note Hashed solely by `setup_id`.""" setup_id: tp.Hashable = define.field() """Setup id.""" func: tp.Callable = define.field() """Chunkable function.""" options: tp.DictLike = define.field(default=None) """Options dictionary.""" tags: tp.SetLike = define.field(default=None) """Set of tags.""" @staticmethod def get_hash(setup_id: tp.Hashable) -> int: return hash((setup_id,)) @property def hash_key(self) -> tuple: return (self.setup_id,) class ChunkableRegistry(Base): """Class for registering chunkable functions.""" def __init__(self) -> None: self._setups = {} @property def setups(self) -> tp.Dict[tp.Hashable, ChunkedSetup]: """Dict of registered `ChunkedSetup` instances by `ChunkedSetup.setup_id`.""" return self._setups def register( self, func: tp.Callable, setup_id: tp.Optional[tp.Hashable] = None, options: tp.DictLike = None, tags: tp.SetLike = None, ) -> None: """Register a new setup.""" if setup_id is None: setup_id = func.__module__ + "." 
+ func.__name__ setup = ChunkedSetup(setup_id=setup_id, func=func, options=options, tags=tags) self.setups[setup_id] = setup def match_setups(self, expression: tp.Optional[str] = None, context: tp.KwargsLike = None) -> tp.Set[ChunkedSetup]: """Match setups against an expression with each setup being a context.""" matched_setups = set() for setup in self.setups.values(): if expression is None: result = True else: result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context)) checks.assert_instance_of(result, bool) if result: matched_setups.add(setup) return matched_setups def get_setup(self, setup_id_or_func: tp.Union[tp.Hashable, tp.Callable]) -> tp.Optional[ChunkedSetup]: """Get setup by its id or function. `setup_id_or_func` can be an identifier or a function. If it's a function, will build the identifier using its module and name.""" if hasattr(setup_id_or_func, "py_func"): nb_setup_id = setup_id_or_func.__module__ + "." + setup_id_or_func.__name__ if nb_setup_id in self.setups: setup_id = nb_setup_id else: setup_id = setup_id_or_func.py_func.__module__ + "." + setup_id_or_func.py_func.__name__ elif callable(setup_id_or_func): setup_id = setup_id_or_func.__module__ + "." + setup_id_or_func.__name__ else: setup_id = setup_id_or_func if setup_id not in self.setups: return None return self.setups[setup_id] def decorate( self, setup_id_or_func: tp.Union[tp.Hashable, tp.Callable], target_func: tp.Optional[tp.Callable] = None, **kwargs, ) -> tp.Callable: """Decorate the setup's function using the `vectorbtpro.utils.chunking.chunked` decorator. Finds setup using `ChunkableRegistry.get_setup`. Merges setup's options with `options`. 
Specify `target_func` to apply the found setup on another function.""" setup = self.get_setup(setup_id_or_func) if setup is None: raise KeyError(f"Setup for {setup_id_or_func} not registered") if target_func is not None: func = target_func elif callable(setup_id_or_func): func = setup_id_or_func else: func = setup.func return chunked(func, **merge_dicts(setup.options, kwargs)) def resolve_option( self, setup_id_or_func: tp.Union[tp.Hashable, tp.Callable], option: tp.ChunkedOption, target_func: tp.Optional[tp.Callable] = None, **kwargs, ) -> tp.Callable: """Same as `ChunkableRegistry.decorate` but using `vectorbtpro.utils.chunking.resolve_chunked`.""" setup = self.get_setup(setup_id_or_func) if setup is None: if callable(setup_id_or_func): option = resolve_chunked_option(option=option) if option is None: return setup_id_or_func raise KeyError(f"Setup for {setup_id_or_func} not registered") if target_func is not None: func = target_func elif callable(setup_id_or_func): func = setup_id_or_func else: func = setup.func return resolve_chunked(func, option=option, **merge_dicts(setup.options, kwargs)) ch_reg = ChunkableRegistry() """Default registry of type `ChunkableRegistry`.""" def register_chunkable( func: tp.Optional[tp.Callable] = None, setup_id: tp.Optional[tp.Hashable] = None, registry: ChunkableRegistry = ch_reg, tags: tp.SetLike = None, return_wrapped: bool = False, **options, ) -> tp.Callable: """Register a new chunkable function. If `return_wrapped` is True, wraps with the `vectorbtpro.utils.chunking.chunked` decorator. Otherwise, leaves the function as-is (preferred). Options are merged in the following order: * `options` in `vectorbtpro._settings.chunking` * `setup_options.{setup_id}` in `vectorbtpro._settings.chunking` * `options` * `override_options` in `vectorbtpro._settings.chunking` * `override_setup_options.{setup_id}` in `vectorbtpro._settings.chunking` !!! 
note Calling the `register_chunkable` decorator before (or below) the `vectorbtpro.registries.jit_registry.register_jitted` decorator with `return_wrapped` set to True won't work. Doing the same after (or above) `vectorbtpro.registries.jit_registry.register_jitted` will work for calling the function from Python but not from Numba. Generally, avoid wrapping right away and use `ChunkableRegistry.decorate` to perform decoration.""" def decorator(_func: tp.Callable) -> tp.Callable: nonlocal setup_id, options from vectorbtpro._settings import settings chunking_cfg = settings["chunking"] if setup_id is None: setup_id = _func.__module__ + "." + _func.__name__ options = merge_dicts( chunking_cfg.get("options", None), chunking_cfg.get("setup_options", {}).get(setup_id, None), options, chunking_cfg.get("override_options", None), chunking_cfg.get("override_setup_options", {}).get(setup_id, None), ) registry.register(func=_func, setup_id=setup_id, options=options, tags=tags) if return_wrapped: return chunked(_func, **options) return _func if func is None: return decorator return decorator(func) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Global registry for jittables. Jitting is a process of just-in-time compiling functions to make their execution faster. A jitter is a decorator that wraps a regular Python function and returns the decorated function. Depending upon a jitter, this decorated function has the same or at least a similar signature to the function that has been decorated. 
Jitters take various jitter-specific options to change the behavior of execution; that is, a single regular Python function can be decorated by multiple jitter instances (for example, one jitter for decorating a function with `numba.jit` and another jitter for doing the same with `parallel=True` flag). In addition to jitters, vectorbt introduces the concept of tasks. One task can be executed by multiple jitter types (such as NumPy, Numba, and JAX). For example, one can create a task that converts price into returns and implements it using NumPy and Numba. Those implementations are registered by `JITRegistry` as `JitableSetup` instances, are stored in `JITRegistry.jitable_setups`, and can be uniquely identified by the task id and jitter type. Note that `JitableSetup` instances contain only information on how to decorate a function. The decorated function itself and the jitter that has been used are registered as a `JittedSetup` instance and stored in `JITRegistry.jitted_setups`. It acts as a cache to quickly retrieve an already decorated function and to avoid recompilation. Let's implement a task that takes a sum over an array using both NumPy and Numba: ```pycon >>> from vectorbtpro import * >>> @vbt.register_jitted(task_id_or_func='sum') ... def sum_np(a): ... return a.sum() >>> @vbt.register_jitted(task_id_or_func='sum') ... def sum_nb(a): ... out = 0. ... for i in range(a.shape[0]): ... out += a[i] ... 
return out ``` We can see that two new jitable setups were registered: ```pycon >>> vbt.jit_reg.jitable_setups['sum'] {'np': JitableSetup(task_id='sum', jitter_id='np', py_func=, jitter_kwargs={}, tags=None), 'nb': JitableSetup(task_id='sum', jitter_id='nb', py_func=, jitter_kwargs={}, tags=None)} ``` Moreover, two jitted setups were registered for our decorated functions: ```pycon >>> from vectorbtpro.registries.jit_registry import JitableSetup >>> hash_np = JitableSetup.get_hash('sum', 'np') >>> vbt.jit_reg.jitted_setups[hash_np] {3527539: JittedSetup(jitter=, jitted_func=)} >>> hash_nb = JitableSetup.get_hash('sum', 'nb') >>> vbt.jit_reg.jitted_setups[hash_nb] {6326224984503844995: JittedSetup(jitter=, jitted_func=CPUDispatcher())} ``` These setups contain decorated functions with the options passed during the registration. When we call `JITRegistry.resolve` without any additional keyword arguments, `JITRegistry` returns exactly these functions: ```pycon >>> jitted_func = vbt.jit_reg.resolve('sum', jitter='nb') >>> jitted_func CPUDispatcher() >>> jitted_func.targetoptions {'nopython': True, 'nogil': True, 'parallel': False, 'boundscheck': False} ``` Once we pass any other option, the Python function will be redecorated, and another `JittedSetup` instance will be registered: ```pycon >>> jitted_func = vbt.jit_reg.resolve('sum', jitter='nb', nopython=False) >>> jitted_func CPUDispatcher() >>> jitted_func.targetoptions {'nopython': False, 'nogil': True, 'parallel': False, 'boundscheck': False} >>> vbt.jit_reg.jitted_setups[hash_nb] {6326224984503844995: JittedSetup(jitter=, jitted_func=CPUDispatcher()), -2979374923679407948: JittedSetup(jitter=, jitted_func=CPUDispatcher())} ``` ## Templates Templates can be used to dynamically select the jitter or keyword arguments for jitting based on the current context.
For example, let's pick the Numba jitter over any other jitter whenever it is registered for a given task: ```pycon >>> vbt.jit_reg.resolve('sum', jitter=vbt.RepEval("'nb' if 'nb' in task_setups else None")) CPUDispatcher() ``` ## Disabling In case we want to disable jitting, we can simply pass `disable=True` to `JITRegistry.resolve`: ```pycon >>> py_func = vbt.jit_reg.resolve('sum', jitter='nb', disable=True) >>> py_func ``` We can also disable jitting globally: ```pycon >>> vbt.settings.jitting['disable'] = True >>> vbt.jit_reg.resolve('sum', jitter='nb') ``` !!! hint If we don't plan to use any additional options and we have only one jitter registered per task, we can also disable resolution to increase performance. !!! warning Disabling jitting globally only applies to functions resolved using `JITRegistry.resolve`. Any decorated function that is being called directly will be executed as usual. ## Jitted option Since most functions that call other jitted functions in vectorbt have a `jitted` argument, you can pass `jitted` as a dictionary with options, as a string denoting the jitter, or False to disable jitting (see `vectorbtpro.utils.jitting.resolve_jitted_option`): ```pycon >>> def sum_arr(arr, jitted=None): ... func = vbt.jit_reg.resolve_option('sum', jitted) ... return func(arr) >>> arr = np.random.uniform(size=1000000) >>> %timeit sum_arr(arr, jitted='np') 319 µs ± 3.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) >>> %timeit sum_arr(arr, jitted='nb') 1.09 ms ± 4.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) >>> %timeit sum_arr(arr, jitted=dict(jitter='nb', disable=True)) 133 ms ± 2.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) ``` !!! hint A good rule of thumb is: whenever a caller function accepts a `jitted` argument, the jitted functions it calls are most probably resolved using `JITRegistry.resolve_option`.
## Changing options upon registration Options are usually specified upon registration using `register_jitted`: ```pycon >>> from numba import prange >>> @vbt.register_jitted(parallel=True, tags={'can_parallel'}) ... def sum_parallel_nb(a): ... out = np.empty(a.shape[1]) ... for col in prange(a.shape[1]): ... total = 0. ... for i in range(a.shape[0]): ... total += a[i, col] ... out[col] = total ... return out >>> sum_parallel_nb.targetoptions {'nopython': True, 'nogil': True, 'parallel': True, 'boundscheck': False} ``` But what if we wanted to change the registration options of vectorbt's own jitable functions, such as `vectorbtpro.generic.nb.base.diff_nb`? For example, let's disable caching for all Numba functions. ```pycon >>> vbt.settings.jitting.jitters['nb']['override_options'] = dict(cache=False) ``` Since all functions have already been registered, the above statement has no effect: ```pycon >>> vbt.jit_reg.jitable_setups['vectorbtpro.generic.nb.base.diff_nb']['nb'].jitter_kwargs {'cache': True} ``` In order for them to be applied, we need to save the settings to a file and load them before all functions are imported: ```pycon >>> vbt.settings.save('my_settings') ``` Let's restart the runtime and instruct vectorbt to load the file with settings before anything else: ```pycon >>> import os >>> os.environ['VBT_SETTINGS_PATH'] = "my_settings" >>> from vectorbtpro import * >>> vbt.jit_reg.jitable_setups['vectorbtpro.generic.nb.base.diff_nb']['nb'].jitter_kwargs {'cache': False} ``` We can also change the registration options for some specific tasks, and even replace Python functions. For example, we can change the implementation in the deepest places of the core. Let's change the default `ddof` from 0 to 1 in `vectorbtpro.generic.nb.base.nanstd_1d_nb` and disable caching with Numba: ```pycon >>> vbt.nb.nanstd_1d_nb(np.array([1, 2, 3])) 0.816496580927726 >>> def new_nanstd_1d_nb(arr, ddof=1): ... 
return np.sqrt(vbt.nb.nanvar_1d_nb(arr, ddof=ddof)) >>> vbt.settings.jitting.jitters['nb']['tasks']['vectorbtpro.generic.nb.base.nanstd_1d_nb'] = dict( ... replace_py_func=new_nanstd_1d_nb, ... override_options=dict( ... cache=False ... ) ... ) >>> vbt.settings.save('my_settings') ``` After restarting the runtime: ```pycon >>> import os >>> os.environ['VBT_SETTINGS_PATH'] = "my_settings" >>> vbt.nb.nanstd_1d_nb(np.array([1, 2, 3])) 1.0 ``` !!! note All of the above examples require saving the setting to a file, restarting the runtime, setting the path to the file to an environment variable, and only then importing vectorbtpro. ## Changing options upon resolution Another approach but without the need to restart the runtime is by changing the options upon resolution using `JITRegistry.resolve_option`: ```pycon >>> # On specific Numba function >>> vbt.settings.jitting.jitters['nb']['tasks']['vectorbtpro.generic.nb.base.diff_nb'] = dict( ... resolve_kwargs=dict( ... nogil=False ... ) ... ) >>> # disabled >>> vbt.jit_reg.resolve('vectorbtpro.generic.nb.base.diff_nb', jitter='nb').targetoptions {'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False} >>> # still enabled >>> vbt.jit_reg.resolve('sum', jitter='nb').targetoptions {'nopython': True, 'nogil': True, 'parallel': False, 'boundscheck': False} >>> # On each Numba function >>> vbt.settings.jitting.jitters['nb']['resolve_kwargs'] = dict(nogil=False) >>> # disabled >>> vbt.jit_reg.resolve('vectorbtpro.generic.nb.base.diff_nb', jitter='nb').targetoptions {'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False} >>> # disabled >>> vbt.jit_reg.resolve('sum', jitter='nb').targetoptions {'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False} ``` ## Building custom jitters Let's build a custom jitter on top of `vectorbtpro.utils.jitting.NumbaJitter` that converts any argument that contains a Pandas object to a 2-dimensional NumPy array prior to decoration: ```pycon 
>>> from functools import wraps >>> from vectorbtpro.utils.jitting import NumbaJitter >>> class SafeNumbaJitter(NumbaJitter): ... def decorate(self, py_func, tags=None): ... if self.wrapping_disabled: ... return py_func ... ... @wraps(py_func) ... def wrapper(*args, **kwargs): ... new_args = () ... for arg in args: ... if isinstance(arg, pd.Series): ... arg = np.expand_dims(arg.values, 1) ... elif isinstance(arg, pd.DataFrame): ... arg = arg.values ... new_args += (arg,) ... new_kwargs = dict() ... for k, v in kwargs.items(): ... if isinstance(v, pd.Series): ... v = np.expand_dims(v.values, 1) ... elif isinstance(v, pd.DataFrame): ... v = v.values ... new_kwargs[k] = v ... return NumbaJitter.decorate(self, py_func, tags=tags)(*new_args, **new_kwargs) ... return wrapper ``` After we have defined our jitter class, we need to register it globally: ```pycon >>> vbt.settings.jitting.jitters['safe_nb'] = dict(cls=SafeNumbaJitter) ``` Finally, we can execute any Numba function by specifying our new jitter: ```pycon >>> func = vbt.jit_reg.resolve( ... task_id_or_func=vbt.generic.nb.diff_nb, ... jitter='safe_nb', ... allow_new=True ... ) >>> func(pd.DataFrame([[1, 2], [3, 4]])) array([[nan, nan], [ 2., 2.]]) ``` Whereas executing the same function using the vanilla Numba jitter causes an error: ```pycon >>> func = vbt.jit_reg.resolve(task_id_or_func=vbt.generic.nb.diff_nb) >>> func(pd.DataFrame([[1, 2], [3, 4]])) Failed in nopython mode pipeline (step: nopython frontend) non-precise type pyobject ``` !!! note Make sure to pass a function as `task_id_or_func` if the jitted function hasn't been registered yet. This jitter cannot be used for decorating Numba functions that should be called from other Numba functions since the conversion is done in Python.
""" from vectorbtpro import _typing as tp from vectorbtpro.utils import checks from vectorbtpro.utils.attr_ import DefineMixin, define from vectorbtpro.utils.base import Base from vectorbtpro.utils.config import merge_dicts, atomic_dict from vectorbtpro.utils.jitting import ( Jitter, resolve_jitted_kwargs, resolve_jitter_type, resolve_jitter, get_id_of_jitter_type, get_func_suffix, ) from vectorbtpro.utils.template import RepEval, substitute_templates, CustomTemplate __all__ = [ "JITRegistry", "jit_reg", "register_jitted", ] def get_func_full_name(func: tp.Callable) -> str: """Get full name of the func to be used as task id.""" return func.__module__ + "." + func.__name__ @define class JitableSetup(DefineMixin): """Class that represents a jitable setup. !!! note Hashed solely by `task_id` and `jitter_id`.""" task_id: tp.Hashable = define.field() """Task id.""" jitter_id: tp.Hashable = define.field() """Jitter id.""" py_func: tp.Callable = define.field() """Python function to be jitted.""" jitter_kwargs: tp.KwargsLike = define.field(default=None) """Keyword arguments passed to `vectorbtpro.utils.jitting.resolve_jitter`.""" tags: tp.SetLike = define.field(default=None) """Set of tags.""" @staticmethod def get_hash(task_id: tp.Hashable, jitter_id: tp.Hashable) -> int: return hash((task_id, jitter_id)) @property def hash_key(self) -> tuple: return (self.task_id, self.jitter_id) @define class JittedSetup(DefineMixin): """Class that represents a jitted setup. !!! note Hashed solely by sorted config of `jitter`. 
That is, two jitters with the same config will yield the same hash and the function won't be re-decorated.""" jitter: Jitter = define.field() """Jitter that decorated the function.""" jitted_func: tp.Callable = define.field() """Decorated function.""" @staticmethod def get_hash(jitter: Jitter) -> int: return hash(tuple(sorted(jitter.config.items()))) @property def hash_key(self) -> tuple: return tuple(sorted(self.jitter.config.items())) class JITRegistry(Base): """Class for registering jitted functions.""" def __init__(self) -> None: self._jitable_setups = {} self._jitted_setups = {} @property def jitable_setups(self) -> tp.Dict[tp.Hashable, tp.Dict[tp.Hashable, JitableSetup]]: """Dict of registered `JitableSetup` instances by `task_id` and `jitter_id`.""" return self._jitable_setups @property def jitted_setups(self) -> tp.Dict[int, tp.Dict[int, JittedSetup]]: """Nested dict of registered `JittedSetup` instances by hash of their `JitableSetup` instance.""" return self._jitted_setups def register_jitable_setup( self, task_id: tp.Hashable, jitter_id: tp.Hashable, py_func: tp.Callable, jitter_kwargs: tp.KwargsLike = None, tags: tp.Optional[set] = None, ) -> JitableSetup: """Register a jitable setup.""" jitable_setup = JitableSetup( task_id=task_id, jitter_id=jitter_id, py_func=py_func, jitter_kwargs=jitter_kwargs, tags=tags, ) if task_id not in self.jitable_setups: self.jitable_setups[task_id] = dict() if jitter_id not in self.jitable_setups[task_id]: self.jitable_setups[task_id][jitter_id] = jitable_setup return jitable_setup def register_jitted_setup( self, jitable_setup: JitableSetup, jitter: Jitter, jitted_func: tp.Callable, ) -> JittedSetup: """Register a jitted setup.""" jitable_setup_hash = hash(jitable_setup) jitted_setup = JittedSetup(jitter=jitter, jitted_func=jitted_func) jitted_setup_hash = hash(jitted_setup) if jitable_setup_hash not in self.jitted_setups: self.jitted_setups[jitable_setup_hash] = dict() if jitted_setup_hash not in 
self.jitted_setups[jitable_setup_hash]: self.jitted_setups[jitable_setup_hash][jitted_setup_hash] = jitted_setup return jitted_setup def decorate_and_register( self, task_id: tp.Hashable, py_func: tp.Callable, jitter: tp.Optional[tp.JitterLike] = None, jitter_kwargs: tp.KwargsLike = None, tags: tp.Optional[set] = None, ): """Decorate a jitable function and register both jitable and jitted setups.""" if jitter_kwargs is None: jitter_kwargs = {} jitter = resolve_jitter(jitter=jitter, py_func=py_func, **jitter_kwargs) jitter_id = get_id_of_jitter_type(type(jitter)) if jitter_id is None: raise ValueError("Jitter id cannot be None: is jitter registered globally?") jitable_setup = self.register_jitable_setup(task_id, jitter_id, py_func, jitter_kwargs=jitter_kwargs, tags=tags) jitted_func = jitter.decorate(py_func, tags=tags) self.register_jitted_setup(jitable_setup, jitter, jitted_func) return jitted_func def match_jitable_setups( self, expression: tp.Optional[str] = None, context: tp.KwargsLike = None, ) -> tp.Set[JitableSetup]: """Match jitable setups against an expression with each setup being a context.""" matched_setups = set() for setups_by_jitter_id in self.jitable_setups.values(): for setup in setups_by_jitter_id.values(): if expression is None: result = True else: result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context)) checks.assert_instance_of(result, bool) if result: matched_setups.add(setup) return matched_setups def match_jitted_setups( self, jitable_setup: JitableSetup, expression: tp.Optional[str] = None, context: tp.KwargsLike = None, ) -> tp.Set[JittedSetup]: """Match jitted setups of a jitable setup against an expression with each setup a context.""" matched_setups = set() for setup in self.jitted_setups[hash(jitable_setup)].values(): if expression is None: result = True else: result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context)) checks.assert_instance_of(result, bool) if result: 
matched_setups.add(setup) return matched_setups def resolve( self, task_id_or_func: tp.Union[tp.Hashable, tp.Callable], jitter: tp.Optional[tp.Union[tp.JitterLike, CustomTemplate]] = None, disable: tp.Optional[tp.Union[bool, CustomTemplate]] = None, disable_resolution: tp.Optional[bool] = None, allow_new: tp.Optional[bool] = None, register_new: tp.Optional[bool] = None, return_missing_task: bool = False, template_context: tp.KwargsLike = None, tags: tp.Optional[set] = None, **jitter_kwargs, ) -> tp.Union[tp.Hashable, tp.Callable]: """Resolve jitted function for the given task id. For details on the format of `task_id_or_func`, see `register_jitted`. Jitter keyword arguments are merged in the following order: * `jitable_setup.jitter_kwargs` * `jitter.your_jitter.resolve_kwargs` in `vectorbtpro._settings.jitting` * `jitter.your_jitter.tasks.your_task.resolve_kwargs` in `vectorbtpro._settings.jitting` * `jitter_kwargs` Templates are substituted in `jitter`, `disable`, and `jitter_kwargs`. Set `disable` to True to return the Python function without decoration. If `disable_resolution` is enabled globally, `task_id_or_func` is returned unchanged. !!! note `disable` is only being used by `JITRegistry`, not `vectorbtpro.utils.jitting`. !!! note If there are more than one jitted setups registered for a single task id, make sure to provide a jitter. If no jitted setup of type `JittedSetup` was found and `allow_new` is True, decorates and returns the function supplied as `task_id_or_func` (otherwise throws an error). Set `return_missing_task` to True to return `task_id_or_func` if it cannot be found in `JITRegistry.jitable_setups`. 
""" from vectorbtpro._settings import settings jitting_cfg = settings["jitting"] if disable_resolution is None: disable_resolution = jitting_cfg["disable_resolution"] if disable_resolution: return task_id_or_func if allow_new is None: allow_new = jitting_cfg["allow_new"] if register_new is None: register_new = jitting_cfg["register_new"] if hasattr(task_id_or_func, "py_func"): py_func = task_id_or_func.py_func task_id = get_func_full_name(py_func) elif callable(task_id_or_func): py_func = task_id_or_func task_id = get_func_full_name(py_func) else: py_func = None task_id = task_id_or_func if task_id not in self.jitable_setups: if not allow_new: if return_missing_task: return task_id_or_func raise KeyError(f"Task id '{task_id}' not registered") task_setups = self.jitable_setups.get(task_id, dict()) template_context = merge_dicts( jitting_cfg["template_context"], template_context, dict(task_id=task_id, py_func=py_func, task_setups=atomic_dict(task_setups)), ) jitter = substitute_templates(jitter, template_context, eval_id="jitter") if jitter is None and py_func is not None: jitter = get_func_suffix(py_func) if jitter is None: if len(task_setups) > 1: raise ValueError( f"There are multiple registered setups for task id '{task_id}'. Please specify the jitter." 
) elif len(task_setups) == 0: raise ValueError(f"There are no registered setups for task id '{task_id}'") jitable_setup = list(task_setups.values())[0] jitter = jitable_setup.jitter_id jitter_id = jitable_setup.jitter_id else: jitter_type = resolve_jitter_type(jitter=jitter) jitter_id = get_id_of_jitter_type(jitter_type) if jitter_id not in task_setups: if not allow_new: raise KeyError(f"Jitable setup with task id '{task_id}' and jitter id '{jitter_id}' not registered") jitable_setup = None else: jitable_setup = task_setups[jitter_id] if jitter_id is None: raise ValueError("Jitter id cannot be None: is jitter registered globally?") if jitable_setup is None and py_func is None: raise ValueError(f"Unable to find Python function for task id '{task_id}' and jitter id '{jitter_id}'") template_context = merge_dicts( template_context, dict(jitter_id=jitter_id, jitter=jitter, jitable_setup=jitable_setup), ) disable = substitute_templates(disable, template_context, eval_id="disable") if disable is None: disable = jitting_cfg["disable"] if disable: if jitable_setup is None: return py_func return jitable_setup.py_func if not isinstance(jitter, Jitter): jitter_cfg = jitting_cfg["jitters"].get(jitter_id, {}) setup_cfg = jitter_cfg.get("tasks", {}).get(task_id, {}) jitter_kwargs = merge_dicts( jitable_setup.jitter_kwargs if jitable_setup is not None else None, jitter_cfg.get("resolve_kwargs", None), setup_cfg.get("resolve_kwargs", None), jitter_kwargs, ) jitter_kwargs = substitute_templates(jitter_kwargs, template_context, eval_id="jitter_kwargs") jitter = resolve_jitter(jitter=jitter, **jitter_kwargs) if jitable_setup is not None: jitable_hash = hash(jitable_setup) jitted_hash = JittedSetup.get_hash(jitter) if jitable_hash in self.jitted_setups and jitted_hash in self.jitted_setups[jitable_hash]: return self.jitted_setups[jitable_hash][jitted_hash].jitted_func else: if register_new: return self.decorate_and_register( task_id=task_id, py_func=py_func, jitter=jitter, 
jitter_kwargs=jitter_kwargs, tags=tags, ) return jitter.decorate(py_func, tags=tags) jitted_func = jitter.decorate(jitable_setup.py_func, tags=jitable_setup.tags) self.register_jitted_setup(jitable_setup, jitter, jitted_func) return jitted_func def resolve_option( self, task_id: tp.Union[tp.Hashable, tp.Callable], option: tp.JittedOption, **kwargs, ) -> tp.Union[tp.Hashable, tp.Callable]: """Resolve `option` using `vectorbtpro.utils.jitting.resolve_jitted_option` and call `JITRegistry.resolve`.""" kwargs = resolve_jitted_kwargs(option=option, **kwargs) if kwargs is None: kwargs = dict(disable=True) return self.resolve(task_id, **kwargs) jit_reg = JITRegistry() """Default registry of type `JITRegistry`.""" def register_jitted( py_func: tp.Optional[tp.Callable] = None, task_id_or_func: tp.Optional[tp.Union[tp.Hashable, tp.Callable]] = None, registry: JITRegistry = jit_reg, tags: tp.Optional[set] = None, **options, ) -> tp.Callable: """Decorate and register a jitable function using `JITRegistry.decorate_and_register`. If `task_id_or_func` is a callable, gets replaced by the callable's module name and function name. Additionally, the function name may contain a suffix pointing at the jitter (such as `_nb`). 
Options are merged in the following order: * `jitters.{jitter_id}.options` in `vectorbtpro._settings.jitting` * `jitters.{jitter_id}.tasks.{task_id}.options` in `vectorbtpro._settings.jitting` * `options` * `jitters.{jitter_id}.override_options` in `vectorbtpro._settings.jitting` * `jitters.{jitter_id}.tasks.{task_id}.override_options` in `vectorbtpro._settings.jitting` `py_func` can also be overridden using `jitters.your_jitter.tasks.your_task.replace_py_func` in `vectorbtpro._settings.jitting`.""" def decorator(_py_func: tp.Callable) -> tp.Callable: nonlocal options from vectorbtpro._settings import settings jitting_cfg = settings["jitting"] if task_id_or_func is None: task_id = get_func_full_name(_py_func) elif hasattr(task_id_or_func, "py_func"): task_id = get_func_full_name(task_id_or_func.py_func) elif callable(task_id_or_func): task_id = get_func_full_name(task_id_or_func) else: task_id = task_id_or_func jitter = options.pop("jitter", None) jitter_type = resolve_jitter_type(jitter=jitter, py_func=_py_func) jitter_id = get_id_of_jitter_type(jitter_type) jitter_cfg = jitting_cfg["jitters"].get(jitter_id, {}) setup_cfg = jitter_cfg.get("tasks", {}).get(task_id, {}) options = merge_dicts( jitter_cfg.get("options", None), setup_cfg.get("options", None), options, jitter_cfg.get("override_options", None), setup_cfg.get("override_options", None), ) if setup_cfg.get("replace_py_func", None) is not None: _py_func = setup_cfg["replace_py_func"] if task_id_or_func is None: task_id = get_func_full_name(_py_func) return registry.decorate_and_register( task_id=task_id, py_func=_py_func, jitter=jitter, jitter_kwargs=options, tags=tags, ) if py_func is None: return decorator return decorator(py_func) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Global registry for progress bars."""

import uuid

from vectorbtpro import _typing as tp
from vectorbtpro.utils.base import Base

if tp.TYPE_CHECKING:
    from vectorbtpro.utils.pbar import ProgressBar as ProgressBarT
else:
    ProgressBarT = "ProgressBar"

__all__ = [
    "PBarRegistry",
    "pbar_reg",
]


class PBarRegistry(Base):
    """Registry that keeps track of `vectorbtpro.utils.pbar.ProgressBar` instances."""

    @classmethod
    def generate_bar_id(cls) -> tp.Hashable:
        """Return a new bar id in the form of a random UUID4 string."""
        new_id = uuid.uuid4()
        return str(new_id)

    def __init__(self):
        self._instances = {}

    @property
    def instances(self) -> tp.Dict[tp.Hashable, ProgressBarT]:
        """Registered instances keyed by their bar id."""
        return self._instances

    def register_instance(self, instance: ProgressBarT) -> None:
        """Store an instance under its bar id."""
        self.instances[instance.bar_id] = instance

    def deregister_instance(self, instance: ProgressBarT) -> None:
        """Remove the entry stored under the instance's bar id, if there is one."""
        self.instances.pop(instance.bar_id, None)

    def has_conflict(self, instance: ProgressBarT) -> bool:
        """Tell whether another, active instance shares the bar id of this instance.

        An instance whose bar id is None never conflicts."""
        if instance.bar_id is None:
            return False
        return any(
            other is not instance and other.bar_id == instance.bar_id and other.active
            for other in self.instances.values()
        )

    def get_last_active_instance(self) -> tp.Optional[ProgressBarT]:
        """Return the active instance with the greatest open time, or None if none is active."""
        latest = None
        latest_time = None
        for candidate in self.instances.values():
            if not candidate.active:
                continue
            if latest_time is None or candidate.open_time > latest_time:
                latest_time = candidate.open_time
                latest = candidate
        return latest

    def get_first_pending_instance(self) -> tp.Optional[ProgressBarT]:
        """Return the earliest pending instance opened after the last active instance.

        Returns None if there is no active instance or no such pending instance."""
        anchor = self.get_last_active_instance()
        if anchor is None:
            return None
        earliest = None
        earliest_time = None
        for candidate in self.instances.values():
            if not (candidate.pending and candidate.open_time > anchor.open_time):
                continue
            if earliest_time is None or candidate.open_time < earliest_time:
                earliest_time = candidate.open_time
                earliest = candidate
        return earliest

    def get_pending_instance(self, instance: ProgressBarT) -> tp.Optional[ProgressBarT]:
        """Return the pending instance that corresponds to an instance.

        With a bar id, looks for another pending instance that has the same bar id.
        Without one, falls back to `PBarRegistry.get_first_pending_instance`."""
        if instance.bar_id is not None:
            for candidate in self.instances.values():
                if candidate is instance or not candidate.pending:
                    continue
                if candidate.bar_id == instance.bar_id:
                    return candidate
            return None
        return self.get_first_pending_instance()

    def get_parent_instances(self, instance: ProgressBarT) -> tp.List[ProgressBarT]:
        """Return other active instances that were opened before an instance."""
        return [
            other
            for other in self.instances.values()
            if other is not instance and other.active and other.open_time < instance.open_time
        ]

    def get_parent_instance(self, instance: ProgressBarT) -> tp.Optional[ProgressBarT]:
        """Return the most recently opened (active) parent of an instance, or None."""
        closest = None
        closest_time = None
        for candidate in self.get_parent_instances(instance):
            if closest_time is None or candidate.open_time > closest_time:
                closest_time = candidate.open_time
                closest = candidate
        return closest

    def get_child_instances(self, instance: ProgressBarT) -> tp.List[ProgressBarT]:
        """Return other active or pending instances that were opened after an instance."""
        return [
            other
            for other in self.instances.values()
            if other is not instance and (other.active or other.pending) and other.open_time > instance.open_time
        ]

    def clear_instances(self) -> None:
        """Drop all registered instances."""
        self.instances.clear()


pbar_reg = PBarRegistry()
"""Default registry of type `PBarRegistry`."""

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Modules for working with returns.

Offers common financial risk and performance metrics as found in [empyrical](https://github.com/quantopian/empyrical),
an adapter for quantstats, and other features based on returns."""

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from vectorbtpro.returns.accessors import *
    from vectorbtpro.returns.nb import *
    from vectorbtpro.returns.qs_adapter import *

__exclude_from__all__ = [
    "enums",
]

__import_if_installed__ = dict()
__import_if_installed__["qs_adapter"] = "quantstats"

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Custom Pandas accessors for returns. Methods can be accessed as follows: * `ReturnsSRAccessor` -> `pd.Series.vbt.returns.*` * `ReturnsDFAccessor` -> `pd.DataFrame.vbt.returns.*` !!! note The underlying Series/DataFrame must already be a return series. To convert price to returns, use `ReturnsAccessor.from_value`. Grouping is only supported by the methods that accept the `group_by` argument. Accessors do not utilize caching. There are three options to compute returns and get the accessor: ```pycon >>> from vectorbtpro import * >>> price = pd.Series([1.1, 1.2, 1.3, 1.2, 1.1]) >>> # 1. pd.Series.pct_change >>> rets = price.pct_change() >>> ret_acc = rets.vbt.returns(freq='d') >>> # 2. vectorbtpro.generic.accessors.GenericAccessor.to_returns >>> rets = price.vbt.to_returns() >>> ret_acc = rets.vbt.returns(freq='d') >>> # 3. vectorbtpro.returns.accessors.ReturnsAccessor.from_value >>> ret_acc = pd.Series.vbt.returns.from_value(price, freq='d') >>> # vectorbtpro.returns.accessors.ReturnsAccessor.total >>> ret_acc.total() 0.0 ``` The accessors extend `vectorbtpro.generic.accessors`. ```pycon >>> # inherited from GenericAccessor >>> ret_acc.max() 0.09090909090909083 ``` ## Defaults `vectorbtpro.returns.accessors.ReturnsAccessor` accepts `defaults` dictionary where you can pass defaults for arguments used throughout the accessor, such as * `start_value`: The starting value. * `window`: Window length. * `minp`: Minimum number of observations in a window required to have a value. * `ddof`: Delta Degrees of Freedom. * `risk_free`: Constant risk-free return throughout the period. 
* `levy_alpha`: Scaling relation (Levy stability exponent). * `required_return`: Minimum acceptance return of the investor. * `cutoff`: Decimal representing the percentage cutoff for the bottom percentile of returns. * `periods`: Number of observations for annualization. Can be an integer or "dt_periods". Defaults as well as `bm_returns` and `year_freq` can be set globally using settings: ```pycon >>> benchmark = pd.Series([1.05, 1.1, 1.15, 1.1, 1.05]) >>> bm_returns = benchmark.vbt.to_returns() >>> vbt.settings.returns['bm_returns'] = bm_returns ``` ## Stats !!! hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `ReturnsAccessor.metrics`. ```pycon >>> ret_acc.stats() Start 0 End 4 Duration 5 days 00:00:00 Total Return [%] 0 Benchmark Return [%] 0 Annualized Return [%] 0 Annualized Volatility [%] 184.643 Sharpe Ratio 0.691185 Calmar Ratio 0 Max Drawdown [%] 15.3846 Omega Ratio 1.08727 Sortino Ratio 1.17805 Skew 0.00151002 Kurtosis -5.94737 Tail Ratio 1.08985 Common Sense Ratio 1.08985 Value at Risk -0.0823718 Alpha 0.78789 Beta 1.83864 dtype: object ``` !!! note `ReturnsAccessor.stats` does not support grouping. ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `ReturnsAccessor.subplots`. 
`ReturnsAccessor` class has a single subplot based on `ReturnsAccessor.plot_cumulative`: ```pycon >>> ret_acc.plots().show() ``` ![](/assets/images/api/returns_plots.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/returns_plots.dark.svg#only-dark){: .iimg loading=lazy } """ import numpy as np import pandas as pd from pandas.tseries.offsets import BaseOffset from vectorbtpro import _typing as tp from vectorbtpro.accessors import register_vbt_accessor, register_df_vbt_accessor, register_sr_vbt_accessor from vectorbtpro.base.reshaping import to_1d_array, to_2d_array, broadcast_array_to, broadcast_to from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor from vectorbtpro.generic.drawdowns import Drawdowns from vectorbtpro.generic.sim_range import SimRangeMixin from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.returns import nb from vectorbtpro.utils import checks, chunking as ch, datetime_ as dt from vectorbtpro.utils.config import resolve_dict, merge_dicts, HybridConfig, Config from vectorbtpro.utils.decorators import hybrid_property, hybrid_method from vectorbtpro.utils.warnings_ import warn __all__ = [ "ReturnsAccessor", "ReturnsSRAccessor", "ReturnsDFAccessor", ] __pdoc__ = {} ReturnsAccessorT = tp.TypeVar("ReturnsAccessorT", bound="ReturnsAccessor") @register_vbt_accessor("returns") class ReturnsAccessor(GenericAccessor, SimRangeMixin): """Accessor on top of return series. For both, Series and DataFrames. Accessible via `pd.Series.vbt.returns` and `pd.DataFrame.vbt.returns`. Args: obj (pd.Series or pd.DataFrame): Pandas object representing returns. bm_returns (array_like): Pandas object representing benchmark returns. log_returns (bool): Whether returns and benchmark returns are provided as log returns. year_freq (any): Year frequency for annualization purposes. 
defaults (dict): Defaults that override `defaults` in `vectorbtpro._settings.returns`. sim_start (int, datetime_like, or array_like): Simulation start per column. sim_end (int, datetime_like, or array_like): Simulation end per column. **kwargs: Keyword arguments that are passed down to `vectorbtpro.generic.accessors.GenericAccessor`.""" @classmethod def from_value( cls: tp.Type[ReturnsAccessorT], value: tp.ArrayLike, init_value: tp.ArrayLike = np.nan, log_returns: bool = False, sim_start: tp.Optional[tp.Array1d] = None, sim_end: tp.Optional[tp.Array1d] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrapper_kwargs: tp.KwargsLike = None, return_values: bool = False, **kwargs, ) -> tp.Union[ReturnsAccessorT, tp.SeriesFrame]: """Returns a new `ReturnsAccessor` instance with returns calculated from `value`.""" if wrapper_kwargs is None: wrapper_kwargs = {} if not checks.is_any_array(value): value = np.asarray(value) if wrapper is None: wrapper = ArrayWrapper.from_obj(value, **wrapper_kwargs) elif len(wrapper_kwargs) > 0: wrapper = wrapper.replace(**wrapper_kwargs) value = to_2d_array(value) init_value = broadcast_array_to(init_value, value.shape[1]) sim_start = cls.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False) sim_end = cls.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False) func = jit_reg.resolve_option(nb.returns_nb, jitted) func = ch_reg.resolve_option(func, chunked) returns = func( value, init_value=init_value, log_returns=log_returns, sim_start=sim_start, sim_end=sim_end, ) if return_values: return wrapper.wrap(returns, group_by=False) return cls(wrapper, returns, sim_start=sim_start, sim_end=sim_end, **kwargs) @classmethod def resolve_row_stack_kwargs( cls: tp.Type[ReturnsAccessorT], *objs: tp.MaybeTuple[ReturnsAccessorT], **kwargs, ) -> tp.Kwargs: """Resolve keyword arguments for initializing `ReturnsAccessor` after stacking along rows.""" kwargs = 
GenericAccessor.resolve_row_stack_kwargs(*objs, **kwargs) if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, ReturnsAccessor): raise TypeError("Each object to be merged must be an instance of ReturnsAccessor") if "bm_returns" not in kwargs: bm_returns = [] stack_bm_returns = True for obj in objs: if obj.config["bm_returns"] is not None: bm_returns.append(obj.config["bm_returns"]) else: stack_bm_returns = False break if stack_bm_returns: kwargs["bm_returns"] = kwargs["wrapper"].row_stack_arrs( *bm_returns, group_by=False, wrap=False, ) if "sim_start" not in kwargs: kwargs["sim_start"] = cls.row_stack_sim_start(kwargs["wrapper"], *objs) if "sim_end" not in kwargs: kwargs["sim_end"] = cls.row_stack_sim_end(kwargs["wrapper"], *objs) return kwargs @classmethod def resolve_column_stack_kwargs( cls: tp.Type[ReturnsAccessorT], *objs: tp.MaybeTuple[ReturnsAccessorT], reindex_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.Kwargs: """Resolve keyword arguments for initializing `ReturnsAccessor` after stacking along columns.""" kwargs = GenericAccessor.resolve_column_stack_kwargs(*objs, reindex_kwargs=reindex_kwargs, **kwargs) kwargs.pop("reindex_kwargs", None) if len(objs) == 1: objs = objs[0] objs = list(objs) for obj in objs: if not checks.is_instance_of(obj, ReturnsAccessor): raise TypeError("Each object to be merged must be an instance of ReturnsAccessor") if "bm_returns" not in kwargs: bm_returns = [] stack_bm_returns = True for obj in objs: if obj.bm_returns is not None: bm_returns.append(obj.bm_returns) else: stack_bm_returns = False break if stack_bm_returns: kwargs["bm_returns"] = kwargs["wrapper"].column_stack_arrs( *bm_returns, reindex_kwargs=reindex_kwargs, group_by=False, wrap=False, ) if "sim_start" not in kwargs: kwargs["sim_start"] = cls.column_stack_sim_start(kwargs["wrapper"], *objs) if "sim_end" not in kwargs: kwargs["sim_end"] = cls.column_stack_sim_end(kwargs["wrapper"], *objs) return kwargs def 
__init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        log_returns: bool = False,
        year_freq: tp.Optional[tp.FrequencyLike] = None,
        defaults: tp.KwargsLike = None,
        sim_start: tp.Optional[tp.Array1d] = None,
        sim_end: tp.Optional[tp.Array1d] = None,
        **kwargs,
    ) -> None:
        # All arguments are forwarded to the parent initializer as well
        GenericAccessor.__init__(
            self,
            wrapper,
            obj=obj,
            bm_returns=bm_returns,
            log_returns=log_returns,
            year_freq=year_freq,
            defaults=defaults,
            sim_start=sim_start,
            sim_end=sim_end,
            **kwargs,
        )
        SimRangeMixin.__init__(self, sim_start=sim_start, sim_end=sim_end)

        # Raw (unresolved) values; the corresponding properties fall back to global settings
        self._bm_returns = bm_returns
        self._log_returns = log_returns
        self._year_freq = year_freq
        self._defaults = defaults

    @hybrid_property
    def sr_accessor_cls(cls_or_self) -> tp.Type["ReturnsSRAccessor"]:
        """Accessor class for `pd.Series`.

        Always `ReturnsSRAccessor`."""
        return ReturnsSRAccessor

    @hybrid_property
    def df_accessor_cls(cls_or_self) -> tp.Type["ReturnsDFAccessor"]:
        """Accessor class for `pd.DataFrame`.

        Always `ReturnsDFAccessor`."""
        return ReturnsDFAccessor

    def indexing_func(
        self: ReturnsAccessorT,
        *args,
        wrapper_meta: tp.DictLike = None,
        **kwargs,
    ) -> ReturnsAccessorT:
        """Perform indexing on `ReturnsAccessor`."""
        if wrapper_meta is None:
            wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
        new_obj = wrapper_meta["new_wrapper"].wrap(
            self.to_2d_array()[wrapper_meta["row_idxs"], :][:, wrapper_meta["col_idxs"]],
            group_by=False,
        )
        if self._bm_returns is not None:
            new_bm_returns = ArrayWrapper.select_from_flex_array(
                self._bm_returns,
                row_idxs=wrapper_meta["row_idxs"],
                col_idxs=wrapper_meta["col_idxs"],
                rows_changed=wrapper_meta["rows_changed"],
                columns_changed=wrapper_meta["columns_changed"],
            )
        else:
            new_bm_returns = None
        new_sim_start = self.sim_start_indexing_func(wrapper_meta)
        new_sim_end = self.sim_end_indexing_func(wrapper_meta)

        if checks.is_series(new_obj):
            return self.replace(
                cls_=self.sr_accessor_cls,
                wrapper=wrapper_meta["new_wrapper"],
                obj=new_obj,
                bm_returns=new_bm_returns,
sim_start=new_sim_start, sim_end=new_sim_end, ) return self.replace( cls_=self.df_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj, bm_returns=new_bm_returns, sim_start=new_sim_start, sim_end=new_sim_end, ) # ############# Properties ############# # @property def bm_returns(self) -> tp.Optional[tp.SeriesFrame]: """Benchmark returns.""" from vectorbtpro._settings import settings returns_cfg = settings["returns"] bm_returns = self._bm_returns if bm_returns is None: bm_returns = returns_cfg["bm_returns"] if bm_returns is not None: bm_returns = self.wrapper.wrap(bm_returns, group_by=False) return bm_returns def get_bm_returns_acc( self, bm_returns: tp.Optional[tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, ) -> tp.Optional[ReturnsAccessorT]: """Get accessor for benchmark returns.""" if bm_returns is None: bm_returns = self.bm_returns if bm_returns is None: return None sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) return self.replace( obj=bm_returns, bm_returns=None, sim_start=sim_start, sim_end=sim_end, ) @property def bm_returns_acc(self) -> tp.Optional[ReturnsAccessorT]: """`ReturnsAccessor.get_bm_returns_acc` with default arguments.""" return self.get_bm_returns_acc() @property def log_returns(self) -> bool: """Whether returns and benchmark returns are provided as log returns.""" return self._log_returns @classmethod def auto_detect_ann_factor(cls, index: pd.DatetimeIndex) -> tp.Optional[float]: """Auto-detect annualization factor from a datetime index.""" checks.assert_instance_of(index, pd.DatetimeIndex, arg_name="index") if len(index) == 1: return None offset = index[0] + pd.offsets.YearBegin() - index[0] first_date = index[0] + offset last_date = index[-1] + offset next_year_date = last_date + pd.offsets.YearBegin() ratio = (last_date.value - first_date.value) / (next_year_date.value - 
first_date.value) ann_factor = len(index) / ratio ann_factor /= next_year_date.year - first_date.year return ann_factor @classmethod def parse_ann_factor(cls, index: pd.DatetimeIndex, method_name: str = "max") -> tp.Optional[float]: """Parse annualization factor from a datetime index.""" checks.assert_instance_of(index, pd.DatetimeIndex, arg_name="index") if len(index) == 1: return None offset = index[0] + pd.offsets.YearBegin() - index[0] shifted_index = index + offset years = shifted_index.year full_years = years[years < years.max()] if len(full_years) == 0: return None return getattr(full_years.value_counts(), method_name.lower())() @classmethod def ann_factor_to_year_freq( cls, ann_factor: float, freq: tp.PandasFrequency, method_name: tp.Optional[str] = None, ) -> tp.PandasFrequency: """Convert annualization factor into year frequency.""" if method_name not in (None, False): if method_name is True: ann_factor = round(ann_factor) else: ann_factor = getattr(np, method_name.lower())(ann_factor) if checks.is_float(ann_factor) and float.is_integer(ann_factor): ann_factor = int(ann_factor) if checks.is_float(ann_factor) and isinstance(freq, BaseOffset): freq = dt.offset_to_timedelta(freq) return ann_factor * freq @classmethod def year_freq_depends_on_index(cls, year_freq: tp.FrequencyLike) -> bool: """Return whether frequency depends on index.""" if isinstance(year_freq, str): year_freq = " ".join(year_freq.strip().split()) if year_freq == "auto" or year_freq.startswith("auto_"): return True if year_freq.startswith("index_"): return True return False @hybrid_method def get_year_freq( cls_or_self, year_freq: tp.Optional[tp.FrequencyLike] = None, index: tp.Optional[tp.Index] = None, freq: tp.Optional[tp.PandasFrequency] = None, ) -> tp.Optional[tp.PandasFrequency]: """Resolve year frequency. If `year_freq` is "auto", uses `ReturnsAccessor.auto_detect_ann_factor`. 
If `year_freq` is "auto_[method_name]`, also applies the method `np.[method_name]` to the annualization factor, mostly to round it. If `year_freq` is "index_[method_name]", uses `ReturnsAccessor.parse_ann_factor` to determine the annualization factor by applying the method to `pd.DatetimeIndex.year`.""" if not isinstance(cls_or_self, type): if year_freq is None: year_freq = cls_or_self._year_freq if year_freq is None: from vectorbtpro._settings import settings returns_cfg = settings["returns"] year_freq = returns_cfg["year_freq"] if year_freq is None: return None if isinstance(year_freq, str): year_freq = " ".join(year_freq.strip().split()) if cls_or_self.year_freq_depends_on_index(year_freq): if not isinstance(cls_or_self, type): if index is None: index = cls_or_self.wrapper.index if freq is None: freq = cls_or_self.wrapper.freq if index is None or not isinstance(index, pd.DatetimeIndex) or freq is None: return None if year_freq == "auto" or year_freq.startswith("auto_"): ann_factor = cls_or_self.auto_detect_ann_factor(index) if year_freq == "auto": method_name = None else: method_name = year_freq.replace("auto_", "") year_freq = cls_or_self.ann_factor_to_year_freq( ann_factor, dt.to_freq(freq), method_name=method_name, ) else: method_name = year_freq.replace("index_", "") ann_factor = cls_or_self.parse_ann_factor(index, method_name=method_name) year_freq = cls_or_self.ann_factor_to_year_freq( ann_factor, dt.to_freq(freq), method_name=None, ) return dt.to_freq(year_freq) @property def year_freq(self) -> tp.Optional[tp.PandasFrequency]: """Year frequency.""" return self.get_year_freq() @hybrid_method def get_ann_factor( cls_or_self, year_freq: tp.Optional[tp.FrequencyLike] = None, freq: tp.Optional[tp.FrequencyLike] = None, raise_error: bool = False, ) -> tp.Optional[float]: """Get the annualization factor from the year and data frequency.""" if isinstance(cls_or_self, type): from vectorbtpro._settings import settings returns_cfg = settings["returns"] wrapping_cfg 
= settings["wrapping"]

            if year_freq is None:
                year_freq = returns_cfg["year_freq"]
            if freq is None:
                freq = wrapping_cfg["freq"]
            # On the class level there is no index to derive an index-dependent frequency from
            if freq is not None and dt.freq_depends_on_index(freq):
                freq = None
        else:
            if year_freq is None:
                year_freq = cls_or_self.year_freq
            if freq is None:
                freq = cls_or_self.wrapper.freq
        if year_freq is None:
            if not raise_error:
                return None
            raise ValueError(
                "Year frequency is None. "
                "Pass it as `year_freq` or define it globally under `settings.returns`. "
                "To determine year frequency automatically, use 'auto'."
            )
        if freq is None:
            if not raise_error:
                return None
            raise ValueError(
                "Index frequency is None. "
                "Pass it as `freq` or define it globally under `settings.wrapping`. "
                "To determine frequency automatically, use 'auto'."
            )
        # Ratio of the two durations, i.e., the number of index periods per year
        return dt.to_timedelta(year_freq, approximate=True) / dt.to_timedelta(freq, approximate=True)

    @property
    def ann_factor(self) -> float:
        """Annualization factor.

        `ReturnsAccessor.get_ann_factor` with `raise_error=True`: raises a `ValueError`
        if the year frequency or the index frequency is None."""
        return self.get_ann_factor(raise_error=True)

    @hybrid_method
    def get_periods(
        cls_or_self,
        periods: tp.Union[None, str, tp.ArrayLike] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        wrapper: tp.Optional[ArrayWrapper] = None,
        group_by: tp.GroupByLike = None,
    ) -> tp.Optional[tp.ArrayLike]:
        """Prepare periods."""
        if not isinstance(cls_or_self, type) and periods is None:
            periods = cls_or_self.defaults["periods"]
        if isinstance(periods, str) and periods.lower() == "dt_periods":
            if not isinstance(cls_or_self, type):
                if wrapper is None:
                    wrapper = cls_or_self.wrapper
            else:
                checks.assert_not_none(wrapper, arg_name="wrapper")
            sim_start = cls_or_self.resolve_sim_start(
                sim_start=sim_start,
                allow_none=True,
                wrapper=wrapper,
                group_by=group_by,
            )
            sim_end = cls_or_self.resolve_sim_end(
                sim_end=sim_end,
                allow_none=True,
                wrapper=wrapper,
                group_by=group_by,
            )
            if sim_start is not None or sim_end is not None:
                if sim_start is None:
                    sim_start = cls_or_self.resolve_sim_start(
                        sim_start=sim_start,
                        allow_none=False,
                        wrapper=wrapper,
group_by=group_by, ) if sim_end is None: sim_end = cls_or_self.resolve_sim_end( sim_end=sim_end, allow_none=False, wrapper=wrapper, group_by=group_by, ) periods = [] for i in range(len(sim_start)): sim_index = wrapper.index[sim_start[i] : sim_end[i]] if len(sim_index) == 0: periods.append(0) else: periods.append(wrapper.index_acc.get_dt_periods(index=sim_index)) periods = np.asarray(periods) else: periods = wrapper.dt_periods return periods @property def periods(self) -> tp.Optional[tp.ArrayLike]: """Periods.""" return self.get_periods() def deannualize(self, value: float) -> float: """Deannualize a value.""" return np.power(1 + value, 1.0 / self.ann_factor) - 1.0 @property def defaults(self) -> tp.Kwargs: """Defaults for `ReturnsAccessor`. Merges `defaults` from `vectorbtpro._settings.returns` with `defaults` from `ReturnsAccessor.__init__`.""" from vectorbtpro._settings import settings returns_defaults_cfg = settings["returns"]["defaults"] return merge_dicts(returns_defaults_cfg, self._defaults) # ############# Transforming ############# # def mirror( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Mirror returns. 
See `vectorbtpro.returns.nb.mirror_returns_nb`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.mirror_returns_nb, jitted) func = ch_reg.resolve_option(func, chunked) mirrored_returns = func( self.to_2d_array(), log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(mirrored_returns, group_by=False, **resolve_dict(wrap_kwargs)) def cumulative( self, start_value: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Cumulative returns. See `vectorbtpro.returns.nb.cumulative_returns_nb`.""" if start_value is None: start_value = self.defaults["start_value"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.cumulative_returns_nb, jitted) func = ch_reg.resolve_option(func, chunked) cumulative = func( self.to_2d_array(), start_value=start_value, log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(cumulative, group_by=False, **resolve_dict(wrap_kwargs)) def resample( self: ReturnsAccessorT, *args, fill_with_zero: bool = True, wrapper_meta: tp.DictLike = None, **kwargs, ) -> ReturnsAccessorT: """Perform resampling on `ReturnsAccessor`.""" if wrapper_meta is None: wrapper_meta = self.wrapper.resample_meta(*args, **kwargs) new_wrapper = wrapper_meta["new_wrapper"] new_obj = self.resample_apply( wrapper_meta["resampler"], nb.total_return_1d_nb, self.log_returns, ) if fill_with_zero: new_obj = new_obj.vbt.fillna(0.0) if self._bm_returns is not None: new_bm_returns = self.bm_returns.vbt.resample_apply( wrapper_meta["resampler"], nb.total_return_1d_nb, self.log_returns, ) if 
fill_with_zero: new_bm_returns = new_bm_returns.vbt.fillna(0.0) else: new_bm_returns = None new_sim_start = self.resample_sim_start(new_wrapper) new_sim_end = self.resample_sim_end(new_wrapper) return self.replace( wrapper=wrapper_meta["new_wrapper"], obj=new_obj, bm_returns=new_bm_returns, sim_start=new_sim_start, sim_end=new_sim_end, ) def resample_returns( self, rule: tp.AnyRuleLike, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Resample returns to a custom frequency, date offset, or index.""" checks.assert_instance_of(self.obj.index, dt.PandasDatetimeIndex) func = jit_reg.resolve_option(nb.total_return_1d_nb, jitted) chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( args=ch.ArgsTaker( None, ) ), ) return self.resample_apply( rule, func, self.log_returns, jitted=jitted, chunked=chunked, **kwargs, ) def daily( self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Daily returns.""" return self.resample_returns("1D", jitted=jitted, chunked=chunked, **kwargs) def annual( self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Annual returns.""" return self.resample_returns(self.year_freq, jitted=jitted, chunked=chunked, **kwargs) # ############# Metrics ############# # def final_value( self, start_value: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Final value. 
See `vectorbtpro.returns.nb.final_value_nb`.""" if start_value is None: start_value = self.defaults["start_value"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.final_value_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), start_value=start_value, log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="final_value"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_final_value( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, start_value: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Rolling final value. See `vectorbtpro.returns.nb.rolling_final_value_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if start_value is None: start_value = self.defaults["start_value"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_final_value_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, start_value=start_value, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def total( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Total return. 
See `vectorbtpro.returns.nb.total_return_nb`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.total_return_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="total_return"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_total( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """Rolling total return. See `vectorbtpro.returns.nb.rolling_total_return_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_total_return_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def annualized( self, periods: tp.Union[None, str, tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Annualized return. 
See `vectorbtpro.returns.nb.annualized_return_nb`.""" periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.annualized_return_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), self.ann_factor, periods=periods, log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="annualized_return"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_annualized( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling annualized return. See `vectorbtpro.returns.nb.rolling_annualized_return_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_annualized_return_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, self.ann_factor, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def annualized_volatility( self, levy_alpha: tp.Optional[float] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Annualized volatility. 
See `vectorbtpro.returns.nb.annualized_volatility_nb`.""" if levy_alpha is None: levy_alpha = self.defaults["levy_alpha"] if ddof is None: ddof = self.defaults["ddof"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.annualized_volatility_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), self.ann_factor, levy_alpha=levy_alpha, ddof=ddof, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="annualized_volatility"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_annualized_volatility( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, levy_alpha: tp.Optional[float] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling annualized volatility. 
See `vectorbtpro.returns.nb.rolling_annualized_volatility_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if levy_alpha is None: levy_alpha = self.defaults["levy_alpha"] if ddof is None: ddof = self.defaults["ddof"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_annualized_volatility_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, self.ann_factor, levy_alpha=levy_alpha, ddof=ddof, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def calmar_ratio( self, periods: tp.Union[None, str, tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Calmar ratio. 
See `vectorbtpro.returns.nb.calmar_ratio_nb`.""" periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.calmar_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), self.ann_factor, periods=periods, log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="calmar_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_calmar_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Calmar ratio. See `vectorbtpro.returns.nb.rolling_calmar_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_calmar_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, self.ann_factor, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def omega_ratio( self, risk_free: tp.Optional[float] = None, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Omega ratio. 
See `vectorbtpro.returns.nb.omega_ratio_nb`.""" if risk_free is None: risk_free = self.defaults["risk_free"] if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.deannualized_return_nb, jitted) required_return = func(required_return, self.ann_factor) func = jit_reg.resolve_option(nb.omega_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - risk_free - required_return, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="omega_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_omega_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, risk_free: tp.Optional[float] = None, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Omega ratio. 
See `vectorbtpro.returns.nb.rolling_omega_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if risk_free is None: risk_free = self.defaults["risk_free"] if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.deannualized_return_nb, jitted) required_return = func(required_return, self.ann_factor) func = jit_reg.resolve_option(nb.rolling_omega_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - risk_free - required_return, window, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def sharpe_ratio( self, annualized: bool = True, risk_free: tp.Optional[float] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Sharpe ratio. 
See `vectorbtpro.returns.nb.sharpe_ratio_nb`.""" if risk_free is None: risk_free = self.defaults["risk_free"] if ddof is None: ddof = self.defaults["ddof"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.sharpe_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) if annualized: ann_factor = self.ann_factor else: ann_factor = 1 out = func( self.to_2d_array() - risk_free, ann_factor, ddof=ddof, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="sharpe_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_sharpe_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, annualized: bool = True, risk_free: tp.Optional[float] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, stream_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Sharpe ratio. 
See `vectorbtpro.returns.nb.rolling_sharpe_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if risk_free is None: risk_free = self.defaults["risk_free"] if ddof is None: ddof = self.defaults["ddof"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_sharpe_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) if annualized: ann_factor = self.ann_factor else: ann_factor = 1 out = func( self.to_2d_array() - risk_free, window, ann_factor, ddof=ddof, minp=minp, sim_start=sim_start, sim_end=sim_end, stream_mode=stream_mode, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def sharpe_ratio_std( self, risk_free: tp.Optional[float] = None, ddof: tp.Optional[int] = None, bias: bool = True, wrap_kwargs: tp.KwargsLike = None, ): """Standard deviation of the sharpe ratio estimation.""" from scipy import stats as scipy_stats returns = to_2d_array(self.obj) nanmask = np.isnan(returns) if nanmask.any(): returns = returns.copy() returns[nanmask] = 0.0 n = len(returns) skew = scipy_stats.skew(returns, axis=0, bias=bias) kurtosis = scipy_stats.kurtosis(returns, axis=0, bias=bias) sr = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof)) out = np.sqrt((1 + (0.5 * sr**2) - (skew * sr) + (((kurtosis - 3) / 4) * sr**2)) / (n - 1)) wrap_kwargs = merge_dicts(dict(name_or_index="sharpe_ratio_std"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def prob_sharpe_ratio( self, bm_returns: tp.Optional[tp.ArrayLike] = None, risk_free: tp.Optional[float] = None, ddof: tp.Optional[int] = None, bias: bool = True, wrap_kwargs: tp.KwargsLike = None, ): """Probabilistic Sharpe Ratio (PSR).""" from scipy import stats as scipy_stats if bm_returns is None: bm_returns = self.bm_returns if bm_returns is not None: bm_sr = 
to_1d_array( self.replace(obj=bm_returns, bm_returns=None).sharpe_ratio( annualized=False, risk_free=risk_free, ddof=ddof, ) ) else: bm_sr = 0 sr = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof)) sr_std = to_1d_array(self.sharpe_ratio_std(risk_free=risk_free, ddof=ddof, bias=bias)) out = scipy_stats.norm.cdf((sr - bm_sr) / sr_std) wrap_kwargs = merge_dicts(dict(name_or_index="prob_sharpe_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def deflated_sharpe_ratio( self, risk_free: tp.Optional[float] = None, ddof: tp.Optional[int] = None, bias: bool = True, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Deflated Sharpe Ratio (DSR). Expresses the chance that the advertised strategy has a positive Sharpe ratio.""" from scipy import stats as scipy_stats if risk_free is None: risk_free = self.defaults["risk_free"] if ddof is None: ddof = self.defaults["ddof"] sharpe_ratio = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof)) var_sharpe = np.nanvar(sharpe_ratio, ddof=ddof) returns = to_2d_array(self.obj) nanmask = np.isnan(returns) if nanmask.any(): returns = returns.copy() returns[nanmask] = 0.0 skew = scipy_stats.skew(returns, axis=0, bias=bias) kurtosis = scipy_stats.kurtosis(returns, axis=0, bias=bias) SR0 = sharpe_ratio + np.sqrt(var_sharpe) * ( (1 - np.euler_gamma) * scipy_stats.norm.ppf(1 - 1 / self.wrapper.shape_2d[1]) + np.euler_gamma * scipy_stats.norm.ppf(1 - 1 / (self.wrapper.shape_2d[1] * np.e)) ) out = scipy_stats.norm.cdf( ((sharpe_ratio - SR0) * np.sqrt(self.wrapper.shape_2d[0] - 1)) / np.sqrt(1 - skew * sharpe_ratio + ((kurtosis - 1) / 4) * sharpe_ratio**2) ) wrap_kwargs = merge_dicts(dict(name_or_index="deflated_sharpe_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def downside_risk( self, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: 
tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Downside risk. See `vectorbtpro.returns.nb.downside_risk_nb`.""" if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.downside_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - required_return, self.ann_factor, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="downside_risk"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_downside_risk( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling downside risk. 
See `vectorbtpro.returns.nb.rolling_downside_risk_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_downside_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - required_return, window, self.ann_factor, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def sortino_ratio( self, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Sortino ratio. See `vectorbtpro.returns.nb.sortino_ratio_nb`.""" if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.sortino_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - required_return, self.ann_factor, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="sortino_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_sortino_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, required_return: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Sortino ratio. 
See `vectorbtpro.returns.nb.rolling_sortino_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if required_return is None: required_return = self.defaults["required_return"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_sortino_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - required_return, window, self.ann_factor, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def information_ratio( self, bm_returns: tp.Optional[tp.ArrayLike] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Information ratio. 
See `vectorbtpro.returns.nb.information_ratio_nb`.""" if ddof is None: ddof = self.defaults["ddof"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.information_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - to_2d_array(bm_returns), ddof=ddof, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="information_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_information_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, bm_returns: tp.Optional[tp.ArrayLike] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling information ratio. 
See `vectorbtpro.returns.nb.rolling_information_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if ddof is None: ddof = self.defaults["ddof"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_information_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - to_2d_array(bm_returns), window, ddof=ddof, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def beta( self, bm_returns: tp.Optional[tp.ArrayLike] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Beta. 
See `vectorbtpro.returns.nb.beta_nb`.""" if ddof is None: ddof = self.defaults["ddof"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.beta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), to_2d_array(bm_returns), ddof=ddof, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="beta"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_beta( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, bm_returns: tp.Optional[tp.ArrayLike] = None, ddof: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling beta. 
See `vectorbtpro.returns.nb.rolling_beta_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if ddof is None: ddof = self.defaults["ddof"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_beta_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), to_2d_array(bm_returns), window, ddof=ddof, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def alpha( self, bm_returns: tp.Optional[tp.ArrayLike] = None, risk_free: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Alpha. 
See `vectorbtpro.returns.nb.alpha_nb`.""" if risk_free is None: risk_free = self.defaults["risk_free"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.alpha_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - risk_free, to_2d_array(bm_returns) - risk_free, self.ann_factor, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="alpha"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_alpha( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, bm_returns: tp.Optional[tp.ArrayLike] = None, risk_free: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling alpha. 
See `vectorbtpro.returns.nb.rolling_alpha_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if risk_free is None: risk_free = self.defaults["risk_free"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_alpha_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array() - risk_free, to_2d_array(bm_returns) - risk_free, window, self.ann_factor, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def tail_ratio( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Tail ratio. See `vectorbtpro.returns.nb.tail_ratio_nb`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.tail_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) wrap_kwargs = merge_dicts(dict(name_or_index="tail_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_tail_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling tail ratio. 
See `vectorbtpro.returns.nb.rolling_tail_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_tail_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, minp=minp, sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def profit_factor( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Profit factor. See `vectorbtpro.returns.nb.profit_factor_nb`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.profit_factor_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="profit_factor"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_profit_factor( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling profit factor. 
See `vectorbtpro.returns.nb.rolling_profit_factor_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_profit_factor_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def common_sense_ratio( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Common Sense Ratio (CSR). See `vectorbtpro.returns.nb.common_sense_ratio_nb`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.common_sense_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="common_sense_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_common_sense_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Common Sense Ratio (CSR). 
See `vectorbtpro.returns.nb.rolling_common_sense_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_common_sense_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def value_at_risk( self, cutoff: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Value at Risk (VaR). See `vectorbtpro.returns.nb.value_at_risk_nb`.""" if cutoff is None: cutoff = self.defaults["cutoff"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.value_at_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), cutoff=cutoff, sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) wrap_kwargs = merge_dicts(dict(name_or_index="value_at_risk"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_value_at_risk( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, cutoff: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Value at Risk (VaR). 
See `vectorbtpro.returns.nb.rolling_value_at_risk_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if cutoff is None: cutoff = self.defaults["cutoff"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_value_at_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, cutoff=cutoff, minp=minp, sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def cond_value_at_risk( self, cutoff: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Conditional Value at Risk (CVaR). See `vectorbtpro.returns.nb.cond_value_at_risk_nb`.""" if cutoff is None: cutoff = self.defaults["cutoff"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.cond_value_at_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), cutoff=cutoff, sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) wrap_kwargs = merge_dicts(dict(name_or_index="cond_value_at_risk"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_cond_value_at_risk( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, cutoff: tp.Optional[float] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, noarr_mode: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling 
Conditional Value at Risk (CVaR). See `vectorbtpro.returns.nb.rolling_cond_value_at_risk_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if cutoff is None: cutoff = self.defaults["cutoff"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_cond_value_at_risk_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, cutoff=cutoff, minp=minp, sim_start=sim_start, sim_end=sim_end, noarr_mode=noarr_mode, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def capture_ratio( self, bm_returns: tp.Optional[tp.ArrayLike] = None, periods: tp.Union[None, str, tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Capture ratio. 
See `vectorbtpro.returns.nb.capture_ratio_nb`.""" if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.capture_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), to_2d_array(bm_returns), self.ann_factor, periods=periods, log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="capture_ratio"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_capture_ratio( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, bm_returns: tp.Optional[tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling capture ratio. 
See `vectorbtpro.returns.nb.rolling_capture_ratio_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] if bm_returns is None: bm_returns = self.bm_returns checks.assert_not_none(bm_returns, arg_name="bm_returns") bm_returns = broadcast_to(bm_returns, self.obj) sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_capture_ratio_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), to_2d_array(bm_returns), window, self.ann_factor, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) def up_capture_ratio( self, bm_returns: tp.Optional[tp.ArrayLike] = None, periods: tp.Union[None, str, tp.ArrayLike] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Up-market capture ratio. 
See `vectorbtpro.returns.nb.up_capture_ratio_nb`."""
        # Fall back to the benchmark returns stored on the accessor.
        if bm_returns is None:
            bm_returns = self.bm_returns
        checks.assert_not_none(bm_returns, arg_name="bm_returns")
        bm_returns = broadcast_to(bm_returns, self.obj)
        # Periods are derived from the raw `sim_start`/`sim_end` arguments, i.e. before
        # those names are rebound to their resolved versions below.
        periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
        sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
        sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)

        func = jit_reg.resolve_option(nb.up_capture_ratio_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        out = func(
            self.to_2d_array(),
            to_2d_array(bm_returns),
            self.ann_factor,
            periods=periods,
            log_returns=self.log_returns,
            sim_start=sim_start,
            sim_end=sim_end,
        )
        wrap_kwargs = merge_dicts(dict(name_or_index="up_capture_ratio"), wrap_kwargs)
        return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)

    def rolling_up_capture_ratio(
        self,
        window: tp.Optional[int] = None,
        *,
        minp: tp.Optional[int] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.SeriesFrame:
        """Rolling up-market capture ratio.

        See `vectorbtpro.returns.nb.rolling_up_capture_ratio_nb`.

        `window` and `minp` default to the entries of `ReturnsAccessor.defaults`, and
        `bm_returns` defaults to `ReturnsAccessor.bm_returns`. Benchmark returns must not
        be None after this fallback; they are broadcast to the shape of the returns."""
        if window is None:
            window = self.defaults["window"]
        if minp is None:
            minp = self.defaults["minp"]
        if bm_returns is None:
            bm_returns = self.bm_returns
        checks.assert_not_none(bm_returns, arg_name="bm_returns")
        bm_returns = broadcast_to(bm_returns, self.obj)
        sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
        sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)

        func = jit_reg.resolve_option(nb.rolling_up_capture_ratio_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        out = func(
            self.to_2d_array(),
            to_2d_array(bm_returns),
            window,
            self.ann_factor,
            log_returns=self.log_returns,
            minp=minp,
            sim_start=sim_start,
            sim_end=sim_end,
        )
        # Not a reduction: the result is wrapped with `wrap` rather than `wrap_reduced`.
        return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))

    def down_capture_ratio(
        self,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        periods: tp.Union[None, str, tp.ArrayLike] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.MaybeSeries:
        """Down-market capture ratio.

        See `vectorbtpro.returns.nb.down_capture_ratio_nb`.

        `bm_returns` defaults to `ReturnsAccessor.bm_returns` and must not be None after
        this fallback. The result is reduced to one value per column."""
        # Fall back to the benchmark returns stored on the accessor.
        if bm_returns is None:
            bm_returns = self.bm_returns
        checks.assert_not_none(bm_returns, arg_name="bm_returns")
        bm_returns = broadcast_to(bm_returns, self.obj)
        # Periods are derived from the raw `sim_start`/`sim_end` arguments, i.e. before
        # those names are rebound to their resolved versions below.
        periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
        sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
        sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)

        func = jit_reg.resolve_option(nb.down_capture_ratio_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        out = func(
            self.to_2d_array(),
            to_2d_array(bm_returns),
            self.ann_factor,
            periods=periods,
            log_returns=self.log_returns,
            sim_start=sim_start,
            sim_end=sim_end,
        )
        wrap_kwargs = merge_dicts(dict(name_or_index="down_capture_ratio"), wrap_kwargs)
        return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)

    def rolling_down_capture_ratio(
        self,
        window: tp.Optional[int] = None,
        *,
        minp: tp.Optional[int] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.SeriesFrame:
        """Rolling down-market capture ratio.

        See `vectorbtpro.returns.nb.rolling_down_capture_ratio_nb`.

        `window` and `minp` default to the entries of `ReturnsAccessor.defaults`, and
        `bm_returns` defaults to `ReturnsAccessor.bm_returns`. Benchmark returns must not
        be None after this fallback; they are broadcast to the shape of the returns."""
        if window is None:
            window = self.defaults["window"]
        if minp is None:
            minp = self.defaults["minp"]
        if bm_returns is None:
            bm_returns = self.bm_returns
        checks.assert_not_none(bm_returns, arg_name="bm_returns")
        bm_returns = broadcast_to(bm_returns, self.obj)
        sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
        sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)

        func = jit_reg.resolve_option(nb.rolling_down_capture_ratio_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        out = func(
            self.to_2d_array(),
            to_2d_array(bm_returns),
            window,
            self.ann_factor,
            log_returns=self.log_returns,
            minp=minp,
            sim_start=sim_start,
            sim_end=sim_end,
        )
        # Not a reduction: the result is wrapped with `wrap` rather than `wrap_reduced`.
        return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))

    def drawdown(
        self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.SeriesFrame:
        """Relative decline from a peak.

        Computed by building cumulative returns with a start value of 1 and delegating to
        the generic `drawdown` accessor method on the result. `jitted` and `chunked` are
        forwarded to both steps, `wrap_kwargs` only to the second one."""
        return self.cumulative(
            start_value=1,
            sim_start=sim_start,
            sim_end=sim_end,
            jitted=jitted,
            chunked=chunked,
        ).vbt.drawdown(
            jitted=jitted,
            chunked=chunked,
            wrap_kwargs=wrap_kwargs,
        )

    def max_drawdown(
        self,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        jitted: tp.JittedOption = None,
        chunked: tp.ChunkedOption = None,
        wrap_kwargs: tp.KwargsLike = None,
    ) -> tp.MaybeSeries:
        """Maximum Drawdown (MDD).

        See `vectorbtpro.returns.nb.max_drawdown_nb`.
Yields the same out as `max_drawdown` of `ReturnsAccessor.drawdowns`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.max_drawdown_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), log_returns=self.log_returns, sim_start=sim_start, sim_end=sim_end, ) wrap_kwargs = merge_dicts(dict(name_or_index="max_drawdown"), wrap_kwargs) return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs) def rolling_max_drawdown( self, window: tp.Optional[int] = None, *, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeSeries: """Rolling Maximum Drawdown (MDD). See `vectorbtpro.returns.nb.rolling_max_drawdown_nb`.""" if window is None: window = self.defaults["window"] if minp is None: minp = self.defaults["minp"] sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) func = jit_reg.resolve_option(nb.rolling_max_drawdown_nb, jitted) func = ch_reg.resolve_option(func, chunked) out = func( self.to_2d_array(), window, log_returns=self.log_returns, minp=minp, sim_start=sim_start, sim_end=sim_end, ) return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs)) @property def drawdowns(self) -> Drawdowns: """`ReturnsAccessor.get_drawdowns` with default arguments.""" return self.get_drawdowns() def get_drawdowns( self, sim_start: tp.Optional[tp.ArrayLike] = None, sim_end: tp.Optional[tp.ArrayLike] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> Drawdowns: """Generate drawdown records of cumulative returns. 
See `vectorbtpro.generic.drawdowns.Drawdowns`.""" sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False) sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False) return Drawdowns.from_price( self.cumulative( start_value=1.0, sim_start=sim_start, sim_end=sim_end, jitted=jitted, ), sim_start=sim_start, sim_end=sim_end, wrapper=self.wrapper, **kwargs, ) @property def qs(self) -> "QSAdapter": """Quantstats adapter.""" from vectorbtpro.returns.qs_adapter import QSAdapter return QSAdapter(self) # ############# Resolution ############# # def resolve_self( self: ReturnsAccessorT, cond_kwargs: tp.KwargsLike = None, custom_arg_names: tp.Optional[tp.Set[str]] = None, impacts_caching: bool = True, silence_warnings: bool = False, ) -> ReturnsAccessorT: """Resolve self. See `vectorbtpro.base.wrapping.Wrapping.resolve_self`. Creates a copy of this instance `year_freq` is different in `cond_kwargs`.""" if cond_kwargs is None: cond_kwargs = {} if custom_arg_names is None: custom_arg_names = set() reself = Wrapping.resolve_self( self, cond_kwargs=cond_kwargs, custom_arg_names=custom_arg_names, impacts_caching=impacts_caching, silence_warnings=silence_warnings, ) if "year_freq" in cond_kwargs: self_copy = reself.replace(year_freq=cond_kwargs["year_freq"]) if self_copy.year_freq != reself.year_freq: if not silence_warnings: warn( f"Changing the year frequency will create a copy of this object. " f"Consider setting it upon object creation to re-use existing cache." ) for alias in reself.self_aliases: if alias not in custom_arg_names: cond_kwargs[alias] = self_copy cond_kwargs["year_freq"] = self_copy.year_freq if impacts_caching: cond_kwargs["use_caching"] = False return self_copy return reself # ############# Stats ############# # @property def stats_defaults(self) -> tp.Kwargs: """Defaults for `ReturnsAccessor.stats`. 
Merges `vectorbtpro.generic.accessors.GenericAccessor.stats_defaults`,
        defaults from `ReturnsAccessor.defaults` (acting as `settings`), and
        `stats` from `vectorbtpro._settings.returns`.

        The year frequency of this accessor is merged into `settings` as well."""
        from vectorbtpro._settings import settings

        returns_stats_cfg = settings["returns"]["stats"]

        # Later dictionaries are passed last to `merge_dicts`, with the global settings at the end.
        return merge_dicts(
            GenericAccessor.stats_defaults.__get__(self),
            dict(settings=self.defaults),
            dict(settings=dict(year_freq=self.year_freq)),
            returns_stats_cfg,
        )

    # Metric definitions keyed by metric name. Percentage metrics (titles ending with "[%]")
    # scale the calculated value by 100 in `post_calc_func`; `max_dd` additionally flips the sign.
    # NOTE(review): `calc_func` strings look like (dotted) attribute paths resolved against the
    # accessor by the stats builder - confirm against `StatsBuilderMixin`.
    _metrics: tp.ClassVar[Config] = HybridConfig(
        dict(
            start_index=dict(
                title="Start Index",
                calc_func="sim_start_index",
                tags="wrapper",
            ),
            end_index=dict(
                title="End Index",
                calc_func="sim_end_index",
                tags="wrapper",
            ),
            total_duration=dict(
                title="Total Duration",
                calc_func="sim_duration",
                apply_to_timedelta=True,
                tags="wrapper",
            ),
            total_return=dict(
                title="Total Return [%]",
                calc_func="total",
                post_calc_func=lambda self, out, settings: out * 100,
                tags="returns",
            ),
            bm_return=dict(
                title="Benchmark Return [%]",
                calc_func="bm_returns_acc.total",
                post_calc_func=lambda self, out, settings: out * 100,
                check_has_bm_returns=True,
                tags="returns",
            ),
            ann_return=dict(
                title="Annualized Return [%]",
                calc_func="annualized",
                post_calc_func=lambda self, out, settings: out * 100,
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            ann_volatility=dict(
                title="Annualized Volatility [%]",
                calc_func="annualized_volatility",
                post_calc_func=lambda self, out, settings: out * 100,
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            max_dd=dict(
                title="Max Drawdown [%]",
                calc_func="drawdowns.get_max_drawdown",
                post_calc_func=lambda self, out, settings: -out * 100,
                tags=["returns", "drawdowns"],
            ),
            max_dd_duration=dict(
                title="Max Drawdown Duration",
                calc_func="drawdowns.get_max_duration",
                fill_wrap_kwargs=True,
                tags=["returns", "drawdowns", "duration"],
            ),
            sharpe_ratio=dict(
                title="Sharpe Ratio",
                calc_func="sharpe_ratio",
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            calmar_ratio=dict(
                title="Calmar Ratio",
                calc_func="calmar_ratio",
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            omega_ratio=dict(
                title="Omega Ratio",
                calc_func="omega_ratio",
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            sortino_ratio=dict(
                title="Sortino Ratio",
                calc_func="sortino_ratio",
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            skew=dict(
                title="Skew",
                calc_func="obj.skew",
                tags="returns",
            ),
            kurtosis=dict(
                title="Kurtosis",
                calc_func="obj.kurtosis",
                tags="returns",
            ),
            tail_ratio=dict(
                title="Tail Ratio",
                calc_func="tail_ratio",
                tags="returns",
            ),
            common_sense_ratio=dict(
                title="Common Sense Ratio",
                calc_func="common_sense_ratio",
                check_has_freq=True,
                check_has_year_freq=True,
                tags="returns",
            ),
            value_at_risk=dict(
                title="Value at Risk",
                calc_func="value_at_risk",
                tags="returns",
            ),
            alpha=dict(
                title="Alpha",
                calc_func="alpha",
                check_has_freq=True,
                check_has_year_freq=True,
                check_has_bm_returns=True,
                tags="returns",
            ),
            beta=dict(
                title="Beta",
                calc_func="beta",
                check_has_bm_returns=True,
                tags="returns",
            ),
        )
    )

    @property
    def metrics(self) -> Config:
        """Metric definitions of this accessor (the class-level `_metrics` config)."""
        return self._metrics

    # ############# Plotting ############# #

    def plot_cumulative(
        self,
        column: tp.Optional[tp.Label] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        start_value: float = 1,
        sim_start: tp.Optional[tp.ArrayLike] = None,
        sim_end: tp.Optional[tp.ArrayLike] = None,
        fit_sim_range: bool = True,
        fill_to_benchmark: bool = False,
        main_kwargs: tp.KwargsLike = None,
        bm_kwargs: tp.KwargsLike = None,
        pct_scale: bool = False,
        hline_shape_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        xref: str = "x",
        yref: str = "y",
        fig: tp.Optional[tp.BaseFigure] = None,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        """Plot cumulative returns.

        Args:
            column (str): Name of the column to plot.
            bm_returns (array_like): Benchmark return to compare returns against.
                Will broadcast per element.

                Defaults to `ReturnsAccessor.bm_returns`.
            start_value (float): The starting value.

                Gets overridden by 0 if `pct_scale` is True.
            sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
            sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
            fit_sim_range (bool): Whether to fit figure to simulation range.
            fill_to_benchmark (bool): Whether to fill between main and benchmark, or between main and `start_value`.

                Has no effect if there are no benchmark returns.
            main_kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot` for main.
            bm_kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot` for benchmark.
            pct_scale (bool): Whether to use the percentage scale for the y-axis.
            hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for `start_value` line.
            add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
            xref (str): X coordinate axis.
            yref (str): Y coordinate axis.
            fig (Figure or FigureWidget): Figure to add traces to.
            **layout_kwargs: Keyword arguments for layout.

        Usage:
            ```pycon
            >>> np.random.seed(0)
            >>> rets = pd.Series(np.random.uniform(-0.05, 0.05, size=100))
            >>> bm_returns = pd.Series(np.random.uniform(-0.05, 0.05, size=100))
            >>> rets.vbt.returns.plot_cumulative(bm_returns=bm_returns).show()
            ```

            ![](/assets/images/api/plot_cumulative.light.svg#only-light){: .iimg loading=lazy }
            ![](/assets/images/api/plot_cumulative.dark.svg#only-dark){: .iimg loading=lazy }
        """
        from vectorbtpro.utils.figure import make_figure, get_domain
        from vectorbtpro._settings import settings

        plotting_cfg = settings["plotting"]

        # Derive layout keys from axis references, e.g. "x2" -> "xaxis2".
        xaxis = "xaxis" + xref[1:]
        yaxis = "yaxis" + yref[1:]
        def_layout_kwargs = {xaxis: {}, yaxis: {}}
        if pct_scale:
            # On the percentage scale, the baseline is 0 rather than the provided start value.
            start_value = 0
            def_layout_kwargs[yaxis]["tickformat"] = ".2%"
        if fig is None:
            fig = make_figure()
        fig.update_layout(**def_layout_kwargs)
        fig.update_layout(**layout_kwargs)
        x_domain = get_domain(xref, fig)
        # `y_domain` is computed but not used below.
        y_domain = get_domain(yref, fig)

        if bm_returns is None:
            bm_returns = self.bm_returns
        # Filling to the benchmark is only possible if there is a benchmark.
        fill_to_benchmark = fill_to_benchmark and bm_returns is not None

        if bm_returns is not None:
            # Plot benchmark
            bm_returns = broadcast_to(bm_returns, self.obj)
            bm_returns = self.select_col_from_obj(bm_returns, column=column, group_by=False)
            if bm_kwargs is None:
                bm_kwargs = {}
            bm_kwargs = merge_dicts(
                dict(
                    trace_kwargs=dict(
                        line=dict(
                            color=plotting_cfg["color_schema"]["gray"],
                        ),
                        name="Benchmark",
                    )
                ),
                bm_kwargs,
            )
            bm_cumulative_returns = bm_returns.vbt.returns.cumulative(
                start_value=start_value,
                sim_start=sim_start,
                sim_end=sim_end,
            )
            bm_cumulative_returns.vbt.lineplot(**bm_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)
        else:
            bm_cumulative_returns = None

        if main_kwargs is None:
            main_kwargs = {}
        cumulative_returns = self.cumulative(
            start_value=start_value,
            sim_start=sim_start,
            sim_end=sim_end,
        )
        cumulative_returns = self.select_col_from_obj(cumulative_returns, column=column, group_by=False)
        main_kwargs = merge_dicts(
            dict(
                trace_kwargs=dict(
                    line=dict(
                        color=plotting_cfg["color_schema"]["purple"],
                    ),
                ),
                other_trace_kwargs="hidden",
            ),
            main_kwargs,
        )
        # Plot the main line against the benchmark line or against the constant baseline.
        if fill_to_benchmark:
            cumulative_returns.vbt.plot_against(
                bm_cumulative_returns, add_trace_kwargs=add_trace_kwargs, fig=fig, **main_kwargs
            )
        else:
            cumulative_returns.vbt.plot_against(start_value, add_trace_kwargs=add_trace_kwargs, fig=fig, **main_kwargs)

        if hline_shape_kwargs is None:
            hline_shape_kwargs = {}
        # Dashed horizontal baseline at `start_value`, spanning the x-domain of the subplot.
        fig.add_shape(
            **merge_dicts(
                dict(
                    type="line",
                    xref="paper",
                    yref=yref,
                    x0=x_domain[0],
                    y0=start_value,
                    x1=x_domain[1],
                    y1=start_value,
                    line=dict(
                        color="gray",
                        dash="dash",
                    ),
                ),
                hline_shape_kwargs,
            )
        )
        if fit_sim_range:
            fig = self.fit_fig_to_sim_range(
                fig,
                column=column,
                sim_start=sim_start,
                sim_end=sim_end,
                group_by=False,
                xref=xref,
            )
        return fig

    @property
    def plots_defaults(self) -> tp.Kwargs:
        """Defaults for `ReturnsAccessor.plots`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults`,
        defaults from `ReturnsAccessor.defaults` (acting as `settings`), and
        `plots` from `vectorbtpro._settings.returns`.

        The year frequency of this accessor is merged into `settings` as well."""
        from vectorbtpro._settings import settings

        returns_plots_cfg = settings["returns"]["plots"]

        # Later dictionaries are passed last to `merge_dicts`, with the global settings at the end.
        return merge_dicts(
            GenericAccessor.plots_defaults.__get__(self),
            dict(settings=self.defaults),
            dict(settings=dict(year_freq=self.year_freq)),
            returns_plots_cfg,
        )

    # Subplot definitions keyed by subplot name. The only subplot delegates to
    # `ReturnsAccessor.plot_cumulative` (see `plot_func`).
    _subplots: tp.ClassVar[Config] = HybridConfig(
        dict(
            plot_cumulative=dict(
                title="Cumulative Returns",
                yaxis_kwargs=dict(title="Cumulative returns"),
                plot_func="plot_cumulative",
                pass_hline_shape_kwargs=True,
                pass_add_trace_kwargs=True,
                pass_xref=True,
                pass_yref=True,
                tags="returns",
            )
        )
    )

    @property
    def subplots(self) -> Config:
        """Subplot definitions of this accessor (the class-level `_subplots` config)."""
        return self._subplots


ReturnsAccessor.override_metrics_doc(__pdoc__)
ReturnsAccessor.override_subplots_doc(__pdoc__)


@register_sr_vbt_accessor("returns")
class ReturnsSRAccessor(ReturnsAccessor, GenericSRAccessor):
    """Accessor on top of return series. For Series only.

    Accessible via `pd.Series.vbt.returns`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        year_freq: tp.Optional[tp.FrequencyLike] = None,
        defaults: tp.KwargsLike = None,
        sim_start: tp.Optional[tp.Array1d] = None,
        sim_end: tp.Optional[tp.Array1d] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # The generic Series accessor is initialized with `_full_init=False`; the
        # returns-specific initialization follows below and is skipped when this class
        # is itself initialized with `_full_init=False`.
        GenericSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)

        if _full_init:
            ReturnsAccessor.__init__(
                self,
                wrapper,
                obj=obj,
                bm_returns=bm_returns,
                year_freq=year_freq,
                defaults=defaults,
                sim_start=sim_start,
                sim_end=sim_end,
                **kwargs,
            )


@register_df_vbt_accessor("returns")
class ReturnsDFAccessor(ReturnsAccessor, GenericDFAccessor):
    """Accessor on top of return series. For DataFrames only.

    Accessible via `pd.DataFrame.vbt.returns`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        bm_returns: tp.Optional[tp.ArrayLike] = None,
        year_freq: tp.Optional[tp.FrequencyLike] = None,
        defaults: tp.KwargsLike = None,
        sim_start: tp.Optional[tp.Array1d] = None,
        sim_end: tp.Optional[tp.Array1d] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # The generic DataFrame accessor is initialized with `_full_init=False`; the
        # returns-specific initialization follows below and is skipped when this class
        # is itself initialized with `_full_init=False`.
        GenericDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)

        if _full_init:
            ReturnsAccessor.__init__(
                self,
                wrapper,
                obj=obj,
                bm_returns=bm_returns,
                year_freq=year_freq,
                defaults=defaults,
                sim_start=sim_start,
                sim_end=sim_end,
                **kwargs,
            )


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== """Named tuples and enumerated types for returns.""" from vectorbtpro import _typing as tp __pdoc__all__ = __all__ = [ "RollSharpeAIS", "RollSharpeAOS", ] __pdoc__ = {} # ############# States ############# # class RollSharpeAIS(tp.NamedTuple): i: int ret: float pre_window_ret: float cumsum: float cumsum_sq: float nancnt: int window: int minp: tp.Optional[int] ddof: int ann_factor: float __pdoc__[ "RollSharpeAIS" ] = """A named tuple representing the input state of `vectorbtpro.returns.nb.rolling_sharpe_ratio_acc_nb`.""" class RollSharpeAOS(tp.NamedTuple): cumsum: float cumsum_sq: float nancnt: int value: float __pdoc__[ "RollSharpeAOS" ] = """A named tuple representing the output state of `vectorbtpro.returns.nb.rolling_sharpe_ratio_acc_nb`.""" # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Numba-compiled functions for returns. Provides an arsenal of Numba-compiled functions that are used by accessors and for measuring portfolio performance. These only accept NumPy arrays and other Numba-compatible types. !!! note vectorbt treats matrices as first-class citizens and expects input arrays to be 2-dim, unless function has suffix `_1d` or is meant to be input to another function. Data is processed along index (axis 0). 
All functions passed as argument must be Numba-compiled.""" import numpy as np from numba import prange from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro._settings import settings from vectorbtpro.base import chunking as base_ch from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb from vectorbtpro.base.reshaping import to_1d_array_nb from vectorbtpro.generic import nb as generic_nb, enums as generic_enums from vectorbtpro.registries.ch_registry import register_chunkable from vectorbtpro.registries.jit_registry import register_jitted from vectorbtpro.returns.enums import RollSharpeAIS, RollSharpeAOS from vectorbtpro.utils import chunking as ch from vectorbtpro.utils.math_ import add_nb __all__ = [] _inf_to_nan = settings["returns"]["inf_to_nan"] _nan_to_zero = settings["returns"]["nan_to_zero"] # ############# Metrics ############# # @register_jitted(cache=True) def get_return_nb( input_value: float, output_value: float, log_returns: bool = False, inf_to_nan: bool = _inf_to_nan, nan_to_zero: bool = _nan_to_zero, ) -> float: """Calculate return from input and output value.""" if input_value == 0: if output_value == 0: r = 0.0 else: r = np.inf * np.sign(output_value) else: return_value = add_nb(output_value, -input_value) / input_value if log_returns: r = np.log1p(return_value) else: r = return_value if inf_to_nan and np.isinf(r): r = np.nan if nan_to_zero and np.isnan(r): r = 0.0 return r @register_jitted(cache=True) def returns_1d_nb( arr: tp.Array1d, init_value: float = np.nan, log_returns: bool = False, ) -> tp.Array1d: """Calculate returns.""" out = np.empty(arr.shape, dtype=float_) if np.isnan(init_value) and arr.shape[0] > 0: input_value = arr[0] else: input_value = init_value for i in range(arr.shape[0]): output_value = arr[i] out[i] = get_return_nb(input_value, output_value, log_returns=log_returns) input_value = output_value return out @register_chunkable( size=ch.ArraySizer(arg_query="arr", axis=1), 
arg_take_spec=dict( arr=ch.ArraySlicer(axis=1), init_value=base_ch.FlexArraySlicer(), log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def returns_nb( arr: tp.Array2d, init_value: tp.FlexArray1dLike = np.nan, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """2-dim version of `returns_1d_nb`.""" init_value_ = to_1d_array_nb(np.asarray(init_value)) out = np.full(arr.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=arr.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(arr.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _init_value = flex_select_1d_pc_nb(init_value_, col) out[_sim_start:_sim_end, col] = returns_1d_nb( arr[_sim_start:_sim_end, col], init_value=_init_value, log_returns=log_returns, ) return out @register_jitted(cache=True) def mirror_returns_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> tp.Array1d: """Calculate mirrored returns. A mirrored return is an inverse, or negative return. For log returns, it negates each return. 
For simple returns, it uses the formula $\frac{1}{1 + R_t} - 1$.""" out = np.empty(returns.shape, dtype=float_) for i in range(returns.shape[0]): if log_returns: out[i] = -returns[i] else: if returns[i] <= -1: out[i] = np.inf else: out[i] = (1 / (1 + returns[i])) - 1 return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def mirror_returns_nb( returns: tp.Array2d, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """2-dim version of `mirror_returns_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = mirror_returns_1d_nb( returns[_sim_start:_sim_end, col], log_returns=log_returns, ) return out @register_jitted(cache=True) def cumulative_returns_1d_nb( returns: tp.Array1d, start_value: float = 1.0, log_returns: bool = False, ) -> tp.Array1d: """Cumulative returns.""" out = np.empty_like(returns, dtype=float_) if log_returns: cumsum = 0 for i in range(returns.shape[0]): if not np.isnan(returns[i]): cumsum += returns[i] if start_value == 0: out[i] = cumsum else: out[i] = np.exp(cumsum) * start_value else: cumprod = 1 for i in range(returns.shape[0]): if not np.isnan(returns[i]): cumprod *= 1 + returns[i] if start_value == 0: out[i] = cumprod - 1 else: out[i] = cumprod * start_value return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), start_value=None, 
log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def cumulative_returns_nb( returns: tp.Array2d, start_value: float = 1.0, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """2-dim version of `cumulative_returns_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = cumulative_returns_1d_nb( returns[_sim_start:_sim_end, col], start_value=start_value, log_returns=log_returns, ) return out @register_jitted(cache=True) def final_value_1d_nb( returns: tp.Array1d, start_value: float = 1.0, log_returns: bool = False, ) -> float: """Final value.""" if log_returns: cumsum = 0 for i in range(returns.shape[0]): if not np.isnan(returns[i]): cumsum += returns[i] if start_value == 0: return cumsum return np.exp(cumsum) * start_value else: cumprod = 1 for i in range(returns.shape[0]): if not np.isnan(returns[i]): cumprod *= 1 + returns[i] if start_value == 0: return cumprod - 1 return cumprod * start_value @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), start_value=None, log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def final_value_nb( returns: tp.Array2d, start_value: float = 1.0, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `final_value_1d_nb`.""" out = 
np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = final_value_1d_nb( returns[_sim_start:_sim_end, col], start_value=start_value, log_returns=log_returns, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, start_value=None, log_returns=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_final_value_nb( returns: tp.Array2d, window: int, start_value: float = 1.0, log_returns: bool = False, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `final_value_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, final_value_1d_nb, start_value, log_returns, ) return out @register_jitted(cache=True) def total_return_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> float: """Total return.""" return final_value_1d_nb(returns, start_value=0.0, log_returns=log_returns) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) 
@register_jitted(cache=True, tags={"can_parallel"}) def total_return_nb( returns: tp.Array2d, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `total_return_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = total_return_1d_nb(returns[_sim_start:_sim_end, col], log_returns=log_returns) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, log_returns=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_total_return_nb( returns: tp.Array2d, window: int, log_returns: bool = False, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `total_return_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, total_return_1d_nb, log_returns, ) return out @register_jitted(cache=True) def annualized_return_1d_nb( returns: tp.Array1d, ann_factor: float, log_returns: bool = False, periods: tp.Optional[float] = None, ) -> float: """Annualized total return. 
This is equivalent to the compound annual growth rate (CAGR).""" if periods is None: periods = returns.shape[0] final_value = final_value_1d_nb(returns, log_returns=log_returns) if periods == 0: return np.nan return final_value ** (ann_factor / periods) - 1 @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, log_returns=None, periods=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def annualized_return_nb( returns: tp.Array2d, ann_factor: float, log_returns: bool = False, periods: tp.Optional[tp.FlexArray1dLike] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `annualized_return_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) if periods is None: period_ = sim_end_ - sim_start_ else: period_ = to_1d_array_nb(np.asarray(periods).astype(int_)) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _period = flex_select_1d_pc_nb(period_, col) out[col] = annualized_return_1d_nb( returns[_sim_start:_sim_end, col], ann_factor, log_returns=log_returns, periods=_period, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, log_returns=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_annualized_return_nb( returns: tp.Array2d, window: int, ann_factor: float, log_returns: bool = False, minp: tp.Optional[int] = None, sim_start: 
tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `annualized_return_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, annualized_return_1d_nb, ann_factor, log_returns, ) return out @register_jitted(cache=True) def annualized_volatility_1d_nb( returns: tp.Array1d, ann_factor: float, levy_alpha: float = 2.0, ddof: int = 0, ) -> float: """Annualized volatility of a strategy.""" return generic_nb.nanstd_1d_nb(returns, ddof=ddof) * ann_factor ** (1.0 / levy_alpha) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, levy_alpha=None, ddof=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def annualized_volatility_nb( returns: tp.Array2d, ann_factor: float, levy_alpha: float = 2.0, ddof: int = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `annualized_volatility_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = annualized_volatility_1d_nb( returns[_sim_start:_sim_end, col], ann_factor, levy_alpha=levy_alpha, ddof=ddof, ) return out @register_chunkable( 
size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, levy_alpha=None, ddof=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_annualized_volatility_nb( returns: tp.Array2d, window: int, ann_factor: float, levy_alpha: float = 2.0, ddof: int = 0, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `annualized_volatility_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, annualized_volatility_1d_nb, ann_factor, levy_alpha, ddof, ) return out @register_jitted(cache=True) def max_drawdown_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> float: """Total maximum drawdown (MDD).""" cum_ret = np.nan value_max = 1.0 out = 0.0 for i in range(returns.shape[0]): if not np.isnan(returns[i]): if np.isnan(cum_ret): cum_ret = 1.0 if log_returns: ret = np.exp(returns[i]) - 1 else: ret = returns[i] cum_ret *= ret + 1.0 if cum_ret > value_max: value_max = cum_ret elif cum_ret < value_max: dd = cum_ret / value_max - 1 if dd < out: out = dd if np.isnan(cum_ret): return np.nan return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), log_returns=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def max_drawdown_nb( returns: 
tp.Array2d, log_returns: bool = False, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `max_drawdown_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = max_drawdown_1d_nb(returns[_sim_start:_sim_end, col], log_returns=log_returns) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, log_returns=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_max_drawdown_nb( returns: tp.Array2d, window: int, log_returns: bool = False, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `max_drawdown_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, max_drawdown_1d_nb, log_returns, ) return out @register_jitted(cache=True) def calmar_ratio_1d_nb( returns: tp.Array1d, ann_factor: float, log_returns: bool = False, periods: tp.Optional[float] = None, ) -> float: """Calmar ratio, or drawdown ratio, of a strategy.""" max_drawdown = max_drawdown_1d_nb(returns, log_returns=log_returns) if max_drawdown == 0: return np.nan annualized_return 
= annualized_return_1d_nb( returns, ann_factor, log_returns=log_returns, periods=periods, ) if max_drawdown == 0: if annualized_return == 0: return np.nan return np.inf return annualized_return / np.abs(max_drawdown) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, log_returns=None, periods=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def calmar_ratio_nb( returns: tp.Array2d, ann_factor: float, log_returns: bool = False, periods: tp.Optional[tp.FlexArray1dLike] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `calmar_ratio_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) if periods is None: period_ = sim_end_ - sim_start_ else: period_ = to_1d_array_nb(np.asarray(periods).astype(int_)) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue _period = flex_select_1d_pc_nb(period_, col) out[col] = calmar_ratio_1d_nb( returns[_sim_start:_sim_end, col], ann_factor, log_returns=log_returns, periods=_period, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, log_returns=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_calmar_ratio_nb( returns: tp.Array2d, window: int, ann_factor: float, log_returns: bool = False, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, 
) -> tp.Array2d: """Rolling version of `calmar_ratio_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, calmar_ratio_1d_nb, ann_factor, log_returns, ) return out @register_jitted(cache=True) def deannualized_return_nb(ret: float, ann_factor: float) -> float: """Deannualized return.""" if ann_factor == 1: return ret if ann_factor <= -1: return np.nan return (1 + ret) ** (1.0 / ann_factor) - 1 @register_jitted(cache=True) def omega_ratio_1d_nb(returns: tp.Array1d) -> float: """Omega ratio of a strategy.""" numer = 0.0 denom = 0.0 for i in range(returns.shape[0]): ret = returns[i] if ret > 0: numer += ret elif ret < 0: denom -= ret if denom == 0: if numer == 0: return np.nan return np.inf return numer / denom @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def omega_ratio_nb( returns: tp.Array2d, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `omega_ratio_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = omega_ratio_1d_nb(returns[_sim_start:_sim_end, col]) return out @register_chunkable( 
size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_omega_ratio_nb( returns: tp.Array2d, window: int, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `omega_ratio_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, omega_ratio_1d_nb, ) return out @register_jitted(cache=True) def sharpe_ratio_1d_nb( returns: tp.Array1d, ann_factor: float, ddof: int = 0, ) -> float: """Sharpe ratio of a strategy.""" mean = np.nanmean(returns) std = generic_nb.nanstd_1d_nb(returns, ddof=ddof) if std == 0: if mean == 0: return np.nan return np.inf return mean / std * np.sqrt(ann_factor) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, ddof=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def sharpe_ratio_nb( returns: tp.Array2d, ann_factor: float, ddof: int = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `sharpe_ratio_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, 
sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = sharpe_ratio_1d_nb( returns[_sim_start:_sim_end, col], ann_factor, ddof=ddof, ) return out @register_jitted(cache=True) def rolling_sharpe_ratio_acc_nb(in_state: RollSharpeAIS) -> RollSharpeAOS: """Accumulator of `rolling_sharpe_ratio_stream_nb`. Takes a state of type `vectorbtpro.returns.enums.RollSharpeAIS` and returns a state of type `vectorbtpro.returns.enums.RollSharpeAOS`.""" mean_in_state = generic_enums.RollMeanAIS( i=in_state.i, value=in_state.ret, pre_window_value=in_state.pre_window_ret, cumsum=in_state.cumsum, nancnt=in_state.nancnt, window=in_state.window, minp=in_state.minp, ) mean_out_state = generic_nb.rolling_mean_acc_nb(mean_in_state) std_in_state = generic_enums.RollStdAIS( i=in_state.i, value=in_state.ret, pre_window_value=in_state.pre_window_ret, cumsum=in_state.cumsum, cumsum_sq=in_state.cumsum_sq, nancnt=in_state.nancnt, window=in_state.window, minp=in_state.minp, ddof=in_state.ddof, ) std_out_state = generic_nb.rolling_std_acc_nb(std_in_state) mean = mean_out_state.value std = std_out_state.value if std == 0: sharpe = np.nan else: sharpe = mean / std * np.sqrt(in_state.ann_factor) return RollSharpeAOS( cumsum=std_out_state.cumsum, cumsum_sq=std_out_state.cumsum_sq, nancnt=std_out_state.nancnt, value=sharpe, ) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, ddof=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(cache=True, tags={"can_parallel"}) def rolling_sharpe_ratio_stream_nb( returns: tp.Array2d, window: int, ann_factor: float, ddof: int = 0, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: 
"""Rolling Sharpe ratio in a streaming fashion. Uses `rolling_sharpe_ratio_acc_nb` at each iteration.""" if window is None: window = returns.shape[0] if minp is None: minp = window out = np.full(returns.shape, np.nan, dtype=float_) if returns.shape[0] == 0: return out sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue cumsum = 0.0 cumsum_sq = 0.0 nancnt = 0 for i in range(_sim_start, _sim_end): in_state = RollSharpeAIS( i=i - _sim_start, ret=returns[i, col], pre_window_ret=returns[i - window, col] if i - window >= 0 else np.nan, cumsum=cumsum, cumsum_sq=cumsum_sq, nancnt=nancnt, window=window, minp=minp, ddof=ddof, ann_factor=ann_factor, ) out_state = rolling_sharpe_ratio_acc_nb(in_state) cumsum = out_state.cumsum cumsum_sq = out_state.cumsum_sq nancnt = out_state.nancnt out[i, col] = out_state.value return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, ddof=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), stream_mode=None, ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_sharpe_ratio_nb( returns: tp.Array2d, window: int, ann_factor: float, ddof: int = 0, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, stream_mode: bool = True, ) -> tp.Array2d: """Rolling version of `sharpe_ratio_1d_nb`.""" if stream_mode: return rolling_sharpe_ratio_stream_nb( returns, window, ann_factor, minp=minp, ddof=ddof, sim_start=sim_start, sim_end=sim_end, ) out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col 
in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, sharpe_ratio_1d_nb, ann_factor, ddof, ) return out @register_jitted(cache=True) def downside_risk_1d_nb(returns: tp.Array1d, ann_factor: float) -> float: """Downside deviation below a threshold.""" cnt = 0 adj_ret_sqrd_sum = np.nan for i in range(returns.shape[0]): if not np.isnan(returns[i]): cnt += 1 if np.isnan(adj_ret_sqrd_sum): adj_ret_sqrd_sum = 0.0 if returns[i] <= 0: adj_ret_sqrd_sum += returns[i] ** 2 if cnt == 0: return np.nan adj_ret_sqrd_mean = adj_ret_sqrd_sum / cnt return np.sqrt(adj_ret_sqrd_mean) * np.sqrt(ann_factor) @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def downside_risk_nb( returns: tp.Array2d, ann_factor: float, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `downside_risk_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = downside_risk_1d_nb(returns[_sim_start:_sim_end, col], ann_factor) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def 
rolling_downside_risk_nb( returns: tp.Array2d, window: int, ann_factor: float, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `downside_risk_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, downside_risk_1d_nb, ann_factor, ) return out @register_jitted(cache=True) def sortino_ratio_1d_nb(returns: tp.Array1d, ann_factor: float) -> float: """Sortino ratio of a strategy.""" avg_annualized_return = np.nanmean(returns) * ann_factor downside_risk = downside_risk_1d_nb(returns, ann_factor) if downside_risk == 0: if avg_annualized_return == 0: return np.nan return np.inf return avg_annualized_return / downside_risk @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ann_factor=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def sortino_ratio_nb( returns: tp.Array2d, ann_factor: float, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `sortino_ratio_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = sortino_ratio_1d_nb(returns[_sim_start:_sim_end, col], 
ann_factor) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ann_factor=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_sortino_ratio_nb( returns: tp.Array2d, window: int, ann_factor: float, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `sortino_ratio_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, sortino_ratio_1d_nb, ann_factor, ) return out @register_jitted(cache=True) def information_ratio_1d_nb(returns: tp.Array1d, ddof: int = 0) -> float: """Information ratio of a strategy.""" mean = np.nanmean(returns) std = generic_nb.nanstd_1d_nb(returns, ddof=ddof) if std == 0: if mean == 0: return np.nan return np.inf return mean / std @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), ddof=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def information_ratio_nb( returns: tp.Array2d, ddof: int = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `information_ratio_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( 
sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = information_ratio_1d_nb(returns[_sim_start:_sim_end, col], ddof=ddof) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), window=None, ddof=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_information_ratio_nb( returns: tp.Array2d, window: int, ddof: int = 0, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `information_ratio_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb( returns[_sim_start:_sim_end, col], window, minp, information_ratio_1d_nb, ddof, ) return out @register_jitted(cache=True) def beta_1d_nb( returns: tp.Array1d, bm_returns: tp.Array1d, ddof: int = 0, ) -> float: """Beta.""" cov = generic_nb.nancov_1d_nb(returns, bm_returns, ddof=ddof) var = generic_nb.nanvar_1d_nb(bm_returns, ddof=ddof) if var == 0: if cov == 0: return np.nan return np.inf return cov / var @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), bm_returns=ch.ArraySlicer(axis=1), ddof=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="concat", ) @register_jitted(cache=True, tags={"can_parallel"}) def beta_nb( returns: tp.Array2d, 
bm_returns: tp.Array2d, ddof: int = 0, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array1d: """2-dim version of `beta_1d_nb`.""" out = np.full(returns.shape[1], np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[col] = beta_1d_nb( returns[_sim_start:_sim_end, col], bm_returns[_sim_start:_sim_end, col], ddof=ddof, ) return out @register_chunkable( size=ch.ArraySizer(arg_query="returns", axis=1), arg_take_spec=dict( returns=ch.ArraySlicer(axis=1), bm_returns=ch.ArraySlicer(axis=1), window=None, ddof=None, minp=None, sim_start=base_ch.FlexArraySlicer(), sim_end=base_ch.FlexArraySlicer(), ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def rolling_beta_nb( returns: tp.Array2d, bm_returns: tp.Array2d, window: int, ddof: int = 0, minp: tp.Optional[int] = None, sim_start: tp.Optional[tp.FlexArray1dLike] = None, sim_end: tp.Optional[tp.FlexArray1dLike] = None, ) -> tp.Array2d: """Rolling version of `beta_1d_nb`.""" out = np.full(returns.shape, np.nan, dtype=float_) sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb( sim_shape=returns.shape, sim_start=sim_start, sim_end=sim_end, ) for col in prange(returns.shape[1]): _sim_start = sim_start_[col] _sim_end = sim_end_[col] if _sim_start >= _sim_end: continue out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb( returns[_sim_start:_sim_end, col], bm_returns[_sim_start:_sim_end, col], window, minp, beta_1d_nb, ddof, ) return out @register_jitted(cache=True) def alpha_1d_nb( returns: tp.Array1d, bm_returns: tp.Array1d, ann_factor: float, ) -> float: """Annualized alpha.""" beta = beta_1d_nb(returns, bm_returns) return (np.nanmean(returns) - beta * np.nanmean(bm_returns) + 1) ** ann_factor - 1 
@register_chunkable(
    size=ch.ArraySizer(arg_query="returns", axis=1),
    arg_take_spec=dict(
        returns=ch.ArraySlicer(axis=1),
        bm_returns=ch.ArraySlicer(axis=1),
        ann_factor=None,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def alpha_nb(
    returns: tp.Array2d,
    bm_returns: tp.Array2d,
    ann_factor: float,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Column-wise version of `alpha_1d_nb`.

    Each column is evaluated only on the rows between its simulation start (inclusive)
    and end (exclusive) as resolved by `generic_nb.prepare_sim_range_nb`.
    Columns with an empty range keep NaN."""
    n_cols = returns.shape[1]
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=returns.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    out = np.full(n_cols, np.nan, dtype=float_)
    for col in prange(n_cols):
        row_from = sim_start_[col]
        row_to = sim_end_[col]
        if row_from < row_to:
            out[col] = alpha_1d_nb(
                returns[row_from:row_to, col],
                bm_returns[row_from:row_to, col],
                ann_factor,
            )
    return out
@register_jitted(cache=True)
def tail_ratio_1d_nb(returns: tp.Array1d) -> float:
    """Ratio between the absolute 95th and the absolute 5th percentile of `returns`.

    Percentiles are computed with `np.nanpercentile`.

    Returns NaN if both percentiles are zero, and infinity if only the 5th percentile is zero."""
    right_tail = np.abs(np.nanpercentile(returns, 95))
    left_tail = np.abs(np.nanpercentile(returns, 5))
    if left_tail != 0:
        # Regular case; a NaN percentile also lands here and propagates as NaN
        return right_tail / left_tail
    if right_tail == 0:
        return np.nan
    return np.inf
@register_jitted(cache=True)
def profit_factor_1d_nb(returns: tp.Array1d) -> float:
    """Profit factor: sum of positive returns divided by the absolute sum of negative returns.

    NaN and zero returns contribute to neither sum.

    Returns NaN if both sums are zero, and infinity if only the sum of losses is zero."""
    gains = 0
    losses = 0
    for ret in returns:
        # NaN fails both comparisons and is thus skipped without an explicit check
        if ret > 0:
            gains += ret
        elif ret < 0:
            losses += abs(ret)
    if losses != 0:
        return gains / losses
    if gains == 0:
        return np.nan
    return np.inf
@register_jitted(cache=True)
def common_sense_ratio_1d_nb(returns: tp.Array1d) -> float:
    """Common Sense Ratio.

    Product of the tail ratio (`tail_ratio_1d_nb`) and the profit factor (`profit_factor_1d_nb`)
    of the same returns."""
    return tail_ratio_1d_nb(returns) * profit_factor_1d_nb(returns)
@register_chunkable(
    size=ch.ArraySizer(arg_query="returns", axis=1),
    arg_take_spec=dict(
        returns=ch.ArraySlicer(axis=1),
        cutoff=None,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
        noarr_mode=None,
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def value_at_risk_nb(
    returns: tp.Array2d,
    cutoff: float = 0.05,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
    noarr_mode: bool = True,
) -> tp.Array1d:
    """Column-wise version of `value_at_risk_1d_nb` and `value_at_risk_noarr_1d_nb`.

    If `noarr_mode` is True, dispatches to `value_at_risk_noarr_1d_nb`,
    otherwise to `value_at_risk_1d_nb`.

    Each column is evaluated only on the rows between its simulation start (inclusive)
    and end (exclusive). Columns with an empty range keep NaN."""
    n_cols = returns.shape[1]
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=returns.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    out = np.full(n_cols, np.nan, dtype=float_)
    for col in prange(n_cols):
        row_from = sim_start_[col]
        row_to = sim_end_[col]
        if row_from < row_to:
            col_returns = returns[row_from:row_to, col]
            if noarr_mode:
                out[col] = value_at_risk_noarr_1d_nb(col_returns, cutoff=cutoff)
            else:
                out[col] = value_at_risk_1d_nb(col_returns, cutoff=cutoff)
    return out
@register_jitted(cache=True)
def cond_value_at_risk_1d_nb(returns: tp.Array1d, cutoff: float = 0.05) -> float:
    """Conditional value at risk (CVaR) of a returns stream.

    Averages the lowest returns up to and including the one at position
    `int((n - 1) * cutoff)` of the partitioned array, where `n` is the number of non-NaN returns.

    NaN returns are dropped beforehand such that they neither inflate the size of the tail
    nor end up in the average. Returns NaN if there are no non-NaN returns
    (the partition would otherwise fail on an empty array).

    !!! note
        The result is identical to the previous implementation whenever `returns` contains no NaN.
        Dropping NaN is presumably in line with `cond_value_at_risk_noarr_1d_nb`, which
        delegates to a NaN-named helper (to be confirmed against that helper)."""
    valid_returns = returns[~np.isnan(returns)]
    n_valid = valid_returns.shape[0]
    if n_valid == 0:
        return np.nan
    cutoff_index = int((n_valid - 1) * cutoff)
    # Partitioning guarantees that the first `cutoff_index + 1` elements are the smallest ones
    return np.mean(np.partition(valid_returns, cutoff_index)[: cutoff_index + 1])
@register_chunkable(
    size=ch.ArraySizer(arg_query="returns", axis=1),
    arg_take_spec=dict(
        returns=ch.ArraySlicer(axis=1),
        bm_returns=ch.ArraySlicer(axis=1),
        ann_factor=None,
        log_returns=None,
        periods=None,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def capture_ratio_nb(
    returns: tp.Array2d,
    bm_returns: tp.Array2d,
    ann_factor: float,
    log_returns: bool = False,
    periods: tp.Optional[tp.FlexArray1dLike] = None,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
    """Column-wise version of `capture_ratio_1d_nb`.

    If `periods` is None, the number of periods of each column is the length of its
    simulation range. Otherwise, `periods` is cast to an integer array and selected per column.

    Each column is evaluated only on the rows between its simulation start (inclusive)
    and end (exclusive). Columns with an empty range keep NaN."""
    n_cols = returns.shape[1]
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=returns.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    if periods is None:
        period_ = sim_end_ - sim_start_
    else:
        period_ = to_1d_array_nb(np.asarray(periods).astype(int_))

    out = np.full(n_cols, np.nan, dtype=float_)
    for col in prange(n_cols):
        row_from = sim_start_[col]
        row_to = sim_end_[col]
        if row_from < row_to:
            col_periods = flex_select_1d_pc_nb(period_, col)
            out[col] = capture_ratio_1d_nb(
                returns[row_from:row_to, col],
                bm_returns[row_from:row_to, col],
                ann_factor,
                log_returns=log_returns,
                periods=col_periods,
            )
    return out
@register_jitted(cache=True)
def up_capture_ratio_1d_nb(
    returns: tp.Array1d,
    bm_returns: tp.Array1d,
    ann_factor: float,
    log_returns: bool = False,
    periods: tp.Optional[float] = None,
) -> float:
    """Capture ratio for periods when the benchmark return is positive.

    Compounds `returns` and `bm_returns` over those periods only where the benchmark return
    is positive, annualizes both growth factors using `ann_factor / periods`, and divides
    the annualized return by the annualized benchmark return.

    Previously, each series was filtered by the sign of its own returns, which does not
    match the definition above; both series are now selected by the benchmark's sign.

    Periods where either return is NaN are skipped. If `log_returns` is True, returns are
    converted to simple returns first. If `periods` is None, it defaults to the length of `returns`.

    Returns NaN if there is no period with both returns available, if `periods` is zero,
    or if both annualized returns are zero, and infinity if only the annualized
    benchmark return is zero.

    !!! note
        As before, annualization is based on `periods` rather than on the number
        of selected periods."""
    if periods is None:
        periods = returns.shape[0]

    growth = 1.0
    bm_growth = 1.0
    valid_cnt = 0
    for i in range(returns.shape[0]):
        if np.isnan(returns[i]) or np.isnan(bm_returns[i]):
            continue
        if log_returns:
            ret = np.exp(returns[i]) - 1
            bm_ret = np.exp(bm_returns[i]) - 1
        else:
            ret = returns[i]
            bm_ret = bm_returns[i]
        valid_cnt += 1
        # The benchmark alone decides which periods belong to the up market
        if bm_ret > 0:
            growth *= ret + 1.0
            bm_growth *= bm_ret + 1.0

    if valid_cnt == 0 or periods == 0:
        return np.nan
    annualized_return = growth ** (ann_factor / periods) - 1
    annualized_bm_return = bm_growth ** (ann_factor / periods) - 1
    if annualized_bm_return == 0:
        if annualized_return == 0:
            return np.nan
        return np.inf
    return annualized_return / annualized_bm_return
@register_jitted(cache=True)
def down_capture_ratio_1d_nb(
    returns: tp.Array1d,
    bm_returns: tp.Array1d,
    ann_factor: float,
    log_returns: bool = False,
    periods: tp.Optional[float] = None,
) -> float:
    """Capture ratio for periods when the benchmark return is negative.

    Compounds `returns` and `bm_returns` over those periods only where the benchmark return
    is negative, annualizes both growth factors using `ann_factor / periods`, and divides
    the annualized return by the annualized benchmark return.

    Previously, each series was filtered by the sign of its own returns, which does not
    match the definition above; both series are now selected by the benchmark's sign.

    Periods where either return is NaN are skipped. If `log_returns` is True, returns are
    converted to simple returns first. If `periods` is None, it defaults to the length of `returns`.

    Returns NaN if there is no period with both returns available, if `periods` is zero,
    or if both annualized returns are zero, and infinity if only the annualized
    benchmark return is zero.

    !!! note
        As before, annualization is based on `periods` rather than on the number
        of selected periods."""
    if periods is None:
        periods = returns.shape[0]

    growth = 1.0
    bm_growth = 1.0
    valid_cnt = 0
    for i in range(returns.shape[0]):
        if np.isnan(returns[i]) or np.isnan(bm_returns[i]):
            continue
        if log_returns:
            ret = np.exp(returns[i]) - 1
            bm_ret = np.exp(bm_returns[i]) - 1
        else:
            ret = returns[i]
            bm_ret = bm_returns[i]
        valid_cnt += 1
        # The benchmark alone decides which periods belong to the down market
        if bm_ret < 0:
            growth *= ret + 1.0
            bm_growth *= bm_ret + 1.0

    if valid_cnt == 0 or periods == 0:
        return np.nan
    annualized_return = growth ** (ann_factor / periods) - 1
    annualized_bm_return = bm_growth ** (ann_factor / periods) - 1
    if annualized_bm_return == 0:
        if annualized_return == 0:
            return np.nan
        return np.inf
    return annualized_return / annualized_bm_return
@register_chunkable(
    size=ch.ArraySizer(arg_query="returns", axis=1),
    arg_take_spec=dict(
        returns=ch.ArraySlicer(axis=1),
        bm_returns=ch.ArraySlicer(axis=1),
        window=None,
        ann_factor=None,
        log_returns=None,
        minp=None,
        sim_start=base_ch.FlexArraySlicer(),
        sim_end=base_ch.FlexArraySlicer(),
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_down_capture_ratio_nb(
    returns: tp.Array2d,
    bm_returns: tp.Array2d,
    window: int,
    ann_factor: float,
    log_returns: bool = False,
    minp: tp.Optional[int] = None,
    sim_start: tp.Optional[tp.FlexArray1dLike] = None,
    sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
    """Rolling version of `down_capture_ratio_1d_nb`.

    For each column, passes `down_capture_ratio_1d_nb` along with `window`, `minp`, `ann_factor`,
    and `log_returns` to `generic_nb.rolling_reduce_two_1d_nb`, restricted to the rows between
    the column's simulation start (inclusive) and end (exclusive).
    Rows outside of that range keep NaN."""
    sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
        sim_shape=returns.shape,
        sim_start=sim_start,
        sim_end=sim_end,
    )
    out = np.full(returns.shape, np.nan, dtype=float_)
    for col in prange(returns.shape[1]):
        row_from = sim_start_[col]
        row_to = sim_end_[col]
        if row_from < row_to:
            out[row_from:row_to, col] = generic_nb.rolling_reduce_two_1d_nb(
                returns[row_from:row_to, col],
                bm_returns[row_from:row_to, col],
                window,
                minp,
                down_capture_ratio_1d_nb,
                ann_factor,
                log_returns,
            )
    return out
We can access the adapter from `ReturnsAccessor`:

```pycon
>>> from vectorbtpro import *
>>> import quantstats as qs

>>> np.random.seed(42)
>>> rets = pd.Series(np.random.uniform(-0.1, 0.1, size=(100,)))
>>> bm_returns = pd.Series(np.random.uniform(-0.1, 0.1, size=(100,)))

>>> rets.vbt.returns.qs.r_squared(benchmark=bm_returns)
0.0011582111228735541
```

Which is the same as:

```pycon
>>> qs.stats.r_squared(rets, bm_returns)
```

So why not just use `qs.stats` directly? First, we can define all parameters such as benchmark returns
once and avoid passing them repeatedly to every function. Second, vectorbt automatically translates
parameters passed to `ReturnsAccessor` for use in quantstats.

```pycon
>>> # Defaults that vectorbt understands
>>> ret_acc = rets.vbt.returns(
...     bm_returns=bm_returns,
...     freq='d',
...     year_freq='365d',
...     defaults=dict(risk_free=0.001)
... )

>>> ret_acc.qs.r_squared()
0.0011582111228735541

>>> ret_acc.qs.sharpe()
-1.9158923252075455

>>> # Defaults that only quantstats understands
>>> qs_defaults = dict(
...     benchmark=bm_returns,
...     periods=365,
...     rf=0.001
... )
>>> ret_acc_qs = rets.vbt.returns.qs(defaults=qs_defaults)

>>> ret_acc_qs.r_squared()
0.0011582111228735541

>>> ret_acc_qs.sharpe()
-1.9158923252075455
```

The adapter automatically passes the returns to the particular function. It also merges the
defaults defined in the settings, the defaults passed to `ReturnsAccessor`, and the defaults
passed to `QSAdapter` itself, and matches them with the argument names listed in the function's
signature.

For example, the `periods` argument defaults to the annualization factor
`ReturnsAccessor.ann_factor`, which itself is based on the `freq` argument. This makes the results
produced by quantstats and vectorbt at least somewhat similar.
def attach_qs_methods(cls: tp.Type[tp.T], replace_signature: bool = True) -> tp.Type[tp.T]:
    """Class decorator to attach quantstats methods.

    Iterates over the public functions of the quantstats modules `utils`, `stats`, `plots`, and
    `reports` that accept the argument `returns`, and attaches a wrapper method for each of them
    to `cls`. Functions from `plots` get the prefix "plot_" and functions from `reports` get
    the suffix "_report"; other functions keep their name.

    Args:
        cls: Class to attach the methods to. Must be a subclass of `QSAdapter`.
        replace_signature: Whether to replace the signature of each wrapper method with
            the signature of the wrapped function (with `self` and `column` prepended).

    Returns:
        The same class with the methods attached."""
    # Imported lazily so that quantstats is only required once the decorator runs
    import quantstats as qs

    checks.assert_subclass_of(cls, "QSAdapter")

    for module_name in ["utils", "stats", "plots", "reports"]:
        for qs_func_name, qs_func in getmembers(getattr(qs, module_name), isfunction):
            if not qs_func_name.startswith("_") and checks.func_accepts_arg(qs_func, "returns"):
                if module_name == "plots":
                    new_method_name = "plot_" + qs_func_name
                elif module_name == "reports":
                    new_method_name = qs_func_name + "_report"
                else:
                    new_method_name = qs_func_name

                # The wrapped function is bound as the default of `_func` such that each wrapper
                # keeps its own function instead of the last one assigned by the loop.
                # No docstring here: `__doc__` is assigned explicitly further below.
                def new_method(
                    self,
                    *,
                    _func: tp.Callable = qs_func,
                    column: tp.Optional[tp.Label] = None,
                    **kwargs,
                ) -> tp.Any:
                    func_arg_names = get_func_arg_names(_func)
                    has_var_kwargs = has_variable_kwargs(_func)
                    defaults = self.defaults

                    # If the function takes variable keyword arguments, forward everything;
                    # otherwise, forward only those arguments that appear in its signature
                    if has_var_kwargs:
                        pass_kwargs = dict(kwargs)
                    else:
                        pass_kwargs = {}
                    for arg_name in func_arg_names:
                        if arg_name not in kwargs:
                            # Not provided by the user: fall back to the merged defaults first,
                            # then to the values known to the returns accessor
                            if arg_name in defaults:
                                pass_kwargs[arg_name] = defaults[arg_name]
                            elif arg_name == "benchmark":
                                if self.returns_acc.bm_returns is not None:
                                    pass_kwargs["benchmark"] = self.returns_acc.bm_returns
                            elif arg_name == "periods":
                                pass_kwargs["periods"] = int(self.returns_acc.ann_factor)
                            elif arg_name == "periods_per_year":
                                pass_kwargs["periods_per_year"] = int(self.returns_acc.ann_factor)
                        elif not has_var_kwargs:
                            pass_kwargs[arg_name] = kwargs[arg_name]

                    # Select a single column of returns and make sure that it has a string name
                    returns = self.returns_acc.select_col_from_obj(
                        self.returns_acc.obj,
                        column=column,
                        wrapper=self.returns_acc.wrapper.regroup(False),
                    )
                    if returns.name is None:
                        returns = returns.rename("Strategy")
                    else:
                        returns = returns.rename(str(returns.name))
                    null_mask = returns.isnull()
                    if "benchmark" in pass_kwargs and pass_kwargs["benchmark"] is not None:
                        benchmark = pass_kwargs["benchmark"]
                        benchmark = self.returns_acc.select_col_from_obj(
                            benchmark,
                            column=column,
                            wrapper=self.returns_acc.wrapper.regroup(False),
                        )
                        if benchmark.name is None:
                            benchmark = benchmark.rename("Benchmark")
                        else:
                            benchmark = benchmark.rename(str(benchmark.name))
                        # Drop the rows where either the returns or the benchmark are missing
                        # such that both series passed to the function share the same index
                        bm_null_mask = benchmark.isnull()
                        null_mask = null_mask | bm_null_mask
                        benchmark = benchmark.loc[~null_mask]
                        # Convert a timezone-aware index to UTC and make it timezone-naive
                        if isinstance(benchmark.index, pd.DatetimeIndex):
                            if benchmark.index.tz is not None:
                                benchmark = benchmark.tz_convert("utc")
                            if benchmark.index.tz is not None:
                                benchmark = benchmark.tz_localize(None)
                        pass_kwargs["benchmark"] = benchmark
                    returns = returns.loc[~null_mask]
                    # Convert a timezone-aware index to UTC and make it timezone-naive
                    if isinstance(returns.index, pd.DatetimeIndex):
                        if returns.index.tz is not None:
                            returns = returns.tz_convert("utc")
                        if returns.index.tz is not None:
                            returns = returns.tz_localize(None)

                    # Binding raises TypeError on invalid arguments before the function is called
                    signature(_func).bind(returns=returns, **pass_kwargs)
                    return _func(returns=returns, **pass_kwargs)

                if replace_signature:
                    # Replace the function's signature with the original one
                    source_sig = signature(qs_func)
                    new_method_params = tuple(signature(new_method).parameters.values())
                    # Parameters of `new_method` are (self, _func, column, **kwargs)
                    self_arg = new_method_params[0]
                    column_arg = new_method_params[2]
                    # Skip the first parameter of the wrapped function (assumed to be `returns`,
                    # TODO confirm) and make the remaining positional parameters keyword-only
                    other_args = [
                        (
                            p.replace(kind=Parameter.KEYWORD_ONLY)
                            if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
                            else p
                        )
                        for p in list(source_sig.parameters.values())[1:]
                    ]
                    source_sig = source_sig.replace(parameters=(self_arg, column_arg) + tuple(other_args))
                    new_method.__signature__ = source_sig

                new_method.__name__ = new_method_name
                new_method.__module__ = cls.__module__
                new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
                new_method.__doc__ = f"See `quantstats.{module_name}.{qs_func_name}`."
                setattr(cls, new_method_name, new_method)
    return cls


QSAdapterT = tp.TypeVar("QSAdapterT", bound="QSAdapter")


@attach_qs_methods
class QSAdapter(Configured):
    """Adapter class for quantstats.

    Wraps a `ReturnsAccessor` instance and exposes quantstats functions as methods,
    which are attached by `attach_qs_methods`.

    Args:
        returns_acc: Returns accessor that provides the returns, the benchmark returns,
            the annualization factor, and the defaults.
        defaults: Defaults that override those from the settings and the returns accessor.
        **kwargs: Keyword arguments passed to `Configured`."""

    def __init__(self, returns_acc: ReturnsAccessor, defaults: tp.KwargsLike = None, **kwargs) -> None:
        checks.assert_instance_of(returns_acc, ReturnsAccessor)

        self._returns_acc = returns_acc
        self._defaults = defaults

        Configured.__init__(self, returns_acc=returns_acc, defaults=defaults, **kwargs)

    def __call__(self: QSAdapterT, **kwargs) -> QSAdapterT:
        """Allows passing arguments to the initializer.

        Returns the result of `replace` called with `**kwargs`."""
        return self.replace(**kwargs)

    @property
    def returns_acc(self) -> ReturnsAccessor:
        """Returns accessor."""
        return self._returns_acc

    @property
    def defaults_mapping(self) -> tp.Dict:
        """Common argument names in quantstats mapped to `ReturnsAccessor.defaults`."""
        return dict(rf="risk_free", rolling_period="window")

    @property
    def defaults(self) -> tp.Kwargs:
        """Defaults for `QSAdapter`.

        Merges `defaults` from `vectorbtpro._settings.qs_adapter`, `returns_acc.defaults` (with adapted naming),
        and `defaults` from `QSAdapter.__init__`."""
        from vectorbtpro._settings import settings

        qs_adapter_defaults_cfg = settings["qs_adapter"]["defaults"]

        # Translate the defaults of the returns accessor into argument names used by quantstats
        mapped_defaults = dict()
        for k, v in self.defaults_mapping.items():
            if v in self.returns_acc.defaults:
                mapped_defaults[k] = self.returns_acc.defaults[v]
        # Passed in the order: settings, returns accessor, then this adapter
        return merge_dicts(qs_adapter_defaults_cfg, mapped_defaults, self._defaults)
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `OHLCSTCX`."""

from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.ohlcstx import ohlcstx_config, ohlcstx_func_config, _bind_ohlcstx_plot

__all__ = [
    "OHLCSTCX",
]

__pdoc__ = {}

# Same factory and place-function configs as `OHLCSTX`; only the identity
# fields and the mode ("chain" instead of "exits") are overridden.
OHLCSTCX = SignalFactory(
    **ohlcstx_config.merge_with(
        {
            "class_name": "OHLCSTCX",
            "module_name": __name__,
            "short_name": "ohlcstcx",
            "mode": "chain",
        }
    )
).with_place_func(**ohlcstx_func_config)


class _OHLCSTCX(OHLCSTCX):
    """Exit signal generator based on OHLC and stop values.

    Generates chain of `new_entries` and `exits` based on `entries` and
    `vectorbtpro.signals.nb.ohlc_stop_place_nb`.

    See `OHLCSTX` for notes on parameters."""

    plot = _bind_ohlcstx_plot(OHLCSTCX, "new_entries")


# The subclass above only carries the docstring and the plotting method;
# both are copied onto the generated class, which is the one exported.
OHLCSTCX.__doc__ = _OHLCSTCX.__doc__
OHLCSTCX.plot = _OHLCSTCX.plot
OHLCSTCX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `OHLCSTX`."""

import numpy as np
import pandas as pd

from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.enums import StopType
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import ohlc_stop_place_nb
from vectorbtpro.utils.config import ReadonlyConfig, merge_dicts

__all__ = [
    "OHLCSTX",
]

__pdoc__ = {}

ohlcstx_config = ReadonlyConfig(
    dict(
        class_name="OHLCSTX",
        module_name=__name__,
        short_name="ohlcstx",
        mode="exits",
        input_names=["entry_price", "open", "high", "low", "close"],
        in_output_names=["stop_price", "stop_type"],
        param_names=["sl_stop", "tsl_th", "tsl_stop", "tp_stop", "reverse"],
        attr_settings=dict(
            stop_type=dict(dtype=StopType),
        ),
    )
)
"""Factory config for `OHLCSTX`."""

ohlcstx_func_config = ReadonlyConfig(
    dict(
        exit_place_func_nb=ohlc_stop_place_nb,
        exit_settings=dict(
            pass_inputs=["entry_price", "open", "high", "low", "close"],
            pass_in_outputs=["stop_price", "stop_type"],
            pass_params=["sl_stop", "tsl_th", "tsl_stop", "tp_stop", "reverse"],
            pass_kwargs=["is_entry_open"],
        ),
        in_output_settings=dict(
            stop_price=dict(dtype=float_),
            stop_type=dict(dtype=int_),
        ),
        param_settings=dict(
            sl_stop=flex_elem_param_config,
            tsl_th=flex_elem_param_config,
            tsl_stop=flex_elem_param_config,
            tp_stop=flex_elem_param_config,
            reverse=flex_elem_param_config,
        ),
        open=np.nan,
        high=np.nan,
        low=np.nan,
        close=np.nan,
        stop_price=np.nan,
        stop_type=-1,
        sl_stop=np.nan,
        tsl_th=np.nan,
        tsl_stop=np.nan,
        tp_stop=np.nan,
        reverse=False,
        is_entry_open=False,
    )
)
"""Exit function config for `OHLCSTX`."""

OHLCSTX = SignalFactory(**ohlcstx_config).with_place_func(**ohlcstx_func_config)


def _bind_ohlcstx_plot(base_cls: type, entries_attr: str) -> tp.Callable:
    """Build a `plot` method for a generator class created from the `OHLCSTX` configs.

    The returned function first draws the price (OHLC if all four price inputs contain
    data, otherwise the entry price as a line) and then calls the `plot` that `base_cls`
    had at binding time to draw the entry and exit signals on the same figure.

    Args:
        base_cls (type): Generated class whose current `plot` gets wrapped.
        entries_attr (str): Name of the attribute holding the entries.

            Used only to build the docstring of the returned function.

    Returns:
        Function that can be attached to `base_cls` as `plot`.
    """
    # Capture the original method now: the caller later overwrites `base_cls.plot`
    # with the function returned from here.
    base_cls_plot = base_cls.plot

    def plot(
        self,
        column: tp.Optional[tp.Label] = None,
        ohlc_kwargs: tp.KwargsLike = None,
        entry_price_kwargs: tp.KwargsLike = None,
        entry_trace_kwargs: tp.KwargsLike = None,
        exit_trace_kwargs: tp.KwargsLike = None,
        add_trace_kwargs: tp.KwargsLike = None,
        fig: tp.Optional[tp.BaseFigure] = None,
        _base_cls_plot: tp.Callable = base_cls_plot,
        **layout_kwargs,
    ) -> tp.BaseFigure:
        self_col = self.select_col(column=column, group_by=False)
        if ohlc_kwargs is None:
            ohlc_kwargs = {}
        if entry_price_kwargs is None:
            entry_price_kwargs = {}
        if add_trace_kwargs is None:
            add_trace_kwargs = {}

        # A price input counts as available if at least one value is not missing
        open_any = not self_col.open.isnull().all()
        high_any = not self_col.high.isnull().all()
        low_any = not self_col.low.isnull().all()
        close_any = not self_col.close.isnull().all()
        if open_any and high_any and low_any and close_any:
            ohlc_df = pd.concat((self_col.open, self_col.high, self_col.low, self_col.close), axis=1)
            ohlc_df.columns = ["Open", "High", "Low", "Close"]
            # User-provided `ohlc_kwargs` take precedence over layout and the default opacity
            ohlc_kwargs = merge_dicts(layout_kwargs, dict(ohlc_trace_kwargs=dict(opacity=0.5)), ohlc_kwargs)
            fig = ohlc_df.vbt.ohlcv.plot(fig=fig, **ohlc_kwargs)
        else:
            # OHLC is incomplete: fall back to plotting the entry price as a line
            entry_price_kwargs = merge_dicts(layout_kwargs, entry_price_kwargs)
            fig = self_col.entry_price.rename("Entry price").vbt.lineplot(fig=fig, **entry_price_kwargs)
        # Draw the signals on top of the price, positioned at entry and stop prices
        _base_cls_plot(
            self_col,
            entry_y=self_col.entry_price,
            exit_y=self_col.stop_price,
            exit_types=self_col.stop_type_readable,
            entry_trace_kwargs=entry_trace_kwargs,
            exit_trace_kwargs=exit_trace_kwargs,
            add_trace_kwargs=add_trace_kwargs,
            fig=fig,
            **layout_kwargs,
        )
        return fig

    plot.__doc__ = """Plot OHLC, `{0}.{1}` and `{0}.exits`.

    If any of `open`, `high`, `low`, and `close` contains only missing values,
    plots `{0}.entry_price` as a line instead of OHLC.

    Args:
        column (hashable): Label of the column to plot.
        ohlc_kwargs (dict): Keyword arguments passed to `vectorbtpro.ohlcv.accessors.OHLCVDFAccessor.plot`.
        entry_price_kwargs (dict): Keyword arguments passed to `vbt.lineplot` when plotting
            `{0}.entry_price` instead of OHLC.
        entry_trace_kwargs (dict): Keyword arguments passed to
            `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_entries` for `{0}.{1}`.
        exit_trace_kwargs (dict): Keyword arguments passed to
            `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_exits` for `{0}.exits`.
        add_trace_kwargs (dict): Keyword arguments forwarded as `add_trace_kwargs`
            when plotting `{0}.{1}` and `{0}.exits`.
        fig (Figure or FigureWidget): Figure to add traces to.
        **layout_kwargs: Keyword arguments for layout.""".format(
        base_cls.__name__,
        entries_attr,
    )
    if entries_attr == "entries":
        plot.__doc__ += """

    Usage:
        ```pycon
        >>> ohlcstx.iloc[:, 0].plot().show()
        ```

        ![](/assets/images/api/OHLCSTX.light.svg#only-light){: .iimg loading=lazy }
        ![](/assets/images/api/OHLCSTX.dark.svg#only-dark){: .iimg loading=lazy }
    """
    return plot


class _OHLCSTX(OHLCSTX):
    """Exit signal generator based on OHLC and stop values.

    Generates `exits` based on `entries` and `vectorbtpro.signals.nb.ohlc_stop_place_nb`.

    !!! hint
        All parameters can be either a single value (per frame) or a NumPy array (per row, column,
        or element). To generate multiple combinations, pass them as lists.

    !!! warning
        Searches for an exit after each entry. If two entries come one after another, no exit can be placed.
        Consider either cleaning up entry signals prior to passing, or using `OHLCSTCX`.

    Usage:
        Test each stop type:

        ```pycon
        >>> from vectorbtpro import *

        >>> entries = pd.Series([True, False, False, False, False, False])
        >>> price = pd.DataFrame({
        ...     'open': [10, 11, 12, 11, 10, 9],
        ...     'high': [11, 12, 13, 12, 11, 10],
        ...     'low': [9, 10, 11, 10, 9, 8],
        ...     'close': [10, 11, 12, 11, 10, 9]
        ... })
        >>> ohlcstx = vbt.OHLCSTX.run(
        ...     entries,
        ...     price['open'],
        ...     price['open'],
        ...     price['high'],
        ...     price['low'],
        ...     price['close'],
        ...     sl_stop=[0.1, np.nan, np.nan, np.nan],
        ...     tsl_th=[np.nan, np.nan, 0.2, np.nan],
        ...     tsl_stop=[np.nan, 0.1, 0.3, np.nan],
        ...     tp_stop=[np.nan, np.nan, np.nan, 0.1],
        ...     is_entry_open=True
        ... )

        >>> ohlcstx.entries
        ohlcstx_sl_stop     0.1    NaN    NaN    NaN
        ohlcstx_tsl_th      NaN    NaN    0.2    NaN
        ohlcstx_tsl_stop    NaN    0.1    0.3    NaN
        ohlcstx_tp_stop     NaN    NaN    NaN    0.1
        0                  True   True   True   True
        1                 False  False  False  False
        2                 False  False  False  False
        3                 False  False  False  False
        4                 False  False  False  False
        5                 False  False  False  False

        >>> ohlcstx.exits
        ohlcstx_sl_stop     0.1    NaN    NaN    NaN
        ohlcstx_tsl_th      NaN    NaN    0.2    NaN
        ohlcstx_tsl_stop    NaN    0.1    0.3    NaN
        ohlcstx_tp_stop     NaN    NaN    NaN    0.1
        0                 False  False  False  False
        1                 False  False  False   True
        2                 False  False  False  False
        3                 False   True  False  False
        4                  True  False   True  False
        5                 False  False  False  False

        >>> ohlcstx.stop_price
        ohlcstx_sl_stop   0.1   NaN  NaN   NaN
        ohlcstx_tsl_th    NaN   NaN  0.2   NaN
        ohlcstx_tsl_stop  NaN   0.1  0.3   NaN
        ohlcstx_tp_stop   NaN   NaN  NaN   0.1
        0                 NaN   NaN  NaN   NaN
        1                 NaN   NaN  NaN  11.0
        2                 NaN   NaN  NaN   NaN
        3                 NaN  11.7  NaN   NaN
        4                 9.0   NaN  9.1   NaN
        5                 NaN   NaN  NaN   NaN

        >>> ohlcstx.stop_type_readable
        ohlcstx_sl_stop    0.1   NaN   NaN   NaN
        ohlcstx_tsl_th     NaN   NaN   0.2   NaN
        ohlcstx_tsl_stop   NaN   0.1   0.3   NaN
        ohlcstx_tp_stop    NaN   NaN   NaN   0.1
        0                 None  None  None  None
        1                 None  None  None    TP
        2                 None  None  None  None
        3                 None   TSL  None  None
        4                   SL  None   TTP  None
        5                 None  None  None  None
        ```
    """

    plot = _bind_ohlcstx_plot(OHLCSTX, "entries")


# The subclass above only carries the docstring and the plotting method;
# both are copied onto the generated class, which is the one exported.
OHLCSTX.__doc__ = _OHLCSTX.__doc__
OHLCSTX.plot = _OHLCSTX.plot
OHLCSTX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RAND`."""

import numpy as np

from vectorbtpro.indicators.configs import flex_col_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_place_nb

__all__ = [
    "RAND",
]

__pdoc__ = {}

# Entries-only generator with a single parameter `n`,
# which is passed to the placement function.
RAND = SignalFactory(
    class_name="RAND",
    module_name=__name__,
    short_name="rand",
    mode="entries",
    param_names=["n"],
).with_place_func(
    entry_place_func_nb=rand_place_nb,
    entry_settings={"pass_params": ["n"]},
    param_settings={"n": flex_col_param_config},
    seed=None,
)


class _RAND(RAND):
    """Random entry signal generator based on the number of signals.

    Generates `entries` based on `vectorbtpro.signals.nb.rand_place_nb`.

    !!! hint
        Parameter `n` can be either a single value (per frame) or a NumPy array (per column).
        To generate multiple combinations, pass it as a list.

    Usage:
        Test three different entry count values:

        ```pycon
        >>> from vectorbtpro import *

        >>> rand = vbt.RAND.run(input_shape=(6,), n=[1, 2, 3], seed=42)

        >>> rand.entries
        rand_n      1      2      3
        0        True   True   True
        1       False  False   True
        2       False  False  False
        3       False   True  False
        4       False  False   True
        5       False  False  False
        ```

        Entry count can also be set per column:

        ```pycon
        >>> rand = vbt.RAND.run(input_shape=(8, 2), n=[np.array([1, 2]), 3], seed=42)

        >>> rand.entries
        rand_n      1      2      3      3
                    0      1      0      1
        0       False  False   True  False
        1        True  False  False  False
        2       False  False  False   True
        3       False   True   True  False
        4       False  False  False  False
        5       False  False  False   True
        6       False  False   True  False
        7       False   True  False   True
        ```
    """


# The subclass above exists only to carry the docstring for the generated class.
RAND.__doc__ = _RAND.__doc__
RAND.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RANDNX`."""

from vectorbtpro.indicators.configs import flex_col_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_enex_apply_nb

__all__ = [
    "RANDNX",
]

__pdoc__ = {}

# Generates both entries and exits with one apply function; the wait arguments
# are listed in `kwargs_as_args` and default to 1.
RANDNX = SignalFactory(
    class_name="RANDNX",
    module_name=__name__,
    short_name="randnx",
    mode="both",
    param_names=["n"],
).with_apply_func(
    rand_enex_apply_nb,
    require_input_shape=True,
    param_settings={"n": flex_col_param_config},
    kwargs_as_args=["entry_wait", "exit_wait"],
    entry_wait=1,
    exit_wait=1,
    seed=None,
)


class _RANDNX(RANDNX):
    """Random entry and exit signal generator based on the number of signals.

    Generates `entries` and `exits` based on `vectorbtpro.signals.nb.rand_enex_apply_nb`.

    See `RAND` for notes on parameters.

    Usage:
        Test three different entry and exit counts:

        ```pycon
        >>> from vectorbtpro import *

        >>> randnx = vbt.RANDNX.run(
        ...     input_shape=(6,),
        ...     n=[1, 2, 3],
        ...     seed=42)

        >>> randnx.entries
        randnx_n      1      2      3
        0          True   True   True
        1         False  False  False
        2         False   True   True
        3         False  False  False
        4         False  False   True
        5         False  False  False

        >>> randnx.exits
        randnx_n      1      2      3
        0         False  False  False
        1          True   True   True
        2         False  False  False
        3         False   True   True
        4         False  False  False
        5         False  False   True
        ```
    """


# The subclass above exists only to carry the docstring for the generated class.
RANDNX.__doc__ = _RANDNX.__doc__
RANDNX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RANDX`."""

import numpy as np
import pandas as pd

from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_place_nb

__all__ = [
    "RANDX",
]

__pdoc__ = {}

# Exits-only generator without parameters: `n` is not exposed to the user
# but fixed to a one-element array holding 1.
RANDX = SignalFactory(
    class_name="RANDX",
    module_name=__name__,
    short_name="randx",
    mode="exits",
).with_place_func(
    exit_place_func_nb=rand_place_nb,
    exit_settings={
        "pass_kwargs": {"n": np.array([1])},
    },
    seed=None,
)


class _RANDX(RANDX):
    """Random exit signal generator based on the number of signals.

    Generates `exits` based on `entries` and `vectorbtpro.signals.nb.rand_place_nb`.

    See `RAND` for notes on parameters.

    Usage:
        Generate an exit for each entry:

        ```pycon
        >>> from vectorbtpro import *

        >>> entries = pd.Series([True, False, False, True, False, False])
        >>> randx = vbt.RANDX.run(entries, seed=42)

        >>> randx.exits
        0    False
        1    False
        2     True
        3    False
        4     True
        5    False
        dtype: bool
        ```
    """


# The subclass above exists only to carry the docstring for the generated class.
RANDX.__doc__ = _RANDX.__doc__
RANDX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RPROB`."""

import numpy as np

from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb

__all__ = [
    "RPROB",
]

__pdoc__ = {}

# Entries-only generator: `prob` is passed to the placement function as a parameter,
# `pick_first` as a keyword argument.
RPROB = SignalFactory(
    class_name="RPROB",
    module_name=__name__,
    short_name="rprob",
    mode="entries",
    param_names=["prob"],
).with_place_func(
    entry_place_func_nb=rand_by_prob_place_nb,
    entry_settings={
        "pass_params": ["prob"],
        "pass_kwargs": ["pick_first"],
    },
    param_settings={"prob": flex_elem_param_config},
    seed=None,
)


class _RPROB(RPROB):
    """Random entry signal generator based on probabilities.

    Generates `entries` based on `vectorbtpro.signals.nb.rand_by_prob_place_nb`.

    !!! hint
        All parameters can be either a single value (per frame) or a NumPy array (per row, column,
        or element). To generate multiple combinations, pass them as lists.

    Usage:
        Generate three columns with different entry probabilities:

        ```pycon
        >>> from vectorbtpro import *

        >>> rprob = vbt.RPROB.run(input_shape=(5,), prob=[0., 0.5, 1.], seed=42)

        >>> rprob.entries
        rprob_prob    0.0    0.5    1.0
        0           False   True   True
        1           False   True   True
        2           False  False   True
        3           False  False   True
        4           False  False   True
        ```

        Probability can also be set per row, column, or element:

        ```pycon
        >>> rprob = vbt.RPROB.run(input_shape=(5,), prob=np.array([0., 0., 1., 1., 1.]), seed=42)

        >>> rprob.entries
        0    False
        1    False
        2     True
        3     True
        4     True
        Name: array_0, dtype: bool
        ```
    """


# The subclass above exists only to carry the docstring for the generated class.
RPROB.__doc__ = _RPROB.__doc__
RPROB.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RPROBCX`."""

from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.rprobx import rprobx_config, rprobx_func_config

__all__ = [
    "RPROBCX",
]

__pdoc__ = {}

# Same factory and place-function configs as `RPROBX`; only the identity
# fields and the mode ("chain" instead of "exits") are overridden.
RPROBCX = SignalFactory(
    **rprobx_config.merge_with(
        {
            "class_name": "RPROBCX",
            "module_name": __name__,
            "short_name": "rprobcx",
            "mode": "chain",
        }
    )
).with_place_func(**rprobx_func_config)


class _RPROBCX(RPROBCX):
    """Random exit signal generator based on probabilities.

    Generates chain of `new_entries` and `exits` based on `entries` and
    `vectorbtpro.signals.nb.rand_by_prob_place_nb`.

    See `RPROB` for notes on parameters."""


# The subclass above exists only to carry the docstring for the generated class.
RPROBCX.__doc__ = _RPROBCX.__doc__
RPROBCX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RPROBNX`."""

import numpy as np

from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb

__all__ = [
    "RPROBNX",
]

__pdoc__ = {}

# Entries and exits use the same placement function, each with its own probability parameter
RPROBNX = SignalFactory(
    class_name="RPROBNX",
    module_name=__name__,
    short_name="rprobnx",
    mode="both",
    param_names=["entry_prob", "exit_prob"],
).with_place_func(
    entry_place_func_nb=rand_by_prob_place_nb,
    entry_settings=dict(
        pass_params=["entry_prob"],
        pass_kwargs=["pick_first"],
    ),
    exit_place_func_nb=rand_by_prob_place_nb,
    exit_settings=dict(
        pass_params=["exit_prob"],
        pass_kwargs=["pick_first"],
    ),
    param_settings=dict(
        entry_prob=flex_elem_param_config,
        exit_prob=flex_elem_param_config,
    ),
    seed=None,
)


class _RPROBNX(RPROBNX):
    """Random entry and exit signal generator based on probabilities.

    Generates `entries` and `exits` based on `vectorbtpro.signals.nb.rand_by_prob_place_nb`.

    See `RPROB` for notes on parameters.

    Usage:
        Test all probability combinations:

        ```pycon
        >>> from vectorbtpro import *

        >>> rprobnx = vbt.RPROBNX.run(
        ...     input_shape=(5,),
        ...     entry_prob=[0.5, 1.],
        ...     exit_prob=[0.5, 1.],
        ...     param_product=True,
        ...     seed=42)

        >>> rprobnx.entries
        rprobnx_entry_prob    0.5    0.5    1.0    1.0
        rprobnx_exit_prob     0.5    1.0    0.5    1.0
        0                    True   True   True   True
        1                   False  False  False  False
        2                   False  False  False   True
        3                   False  False  False  False
        4                   False  False   True   True

        >>> rprobnx.exits
        rprobnx_entry_prob    0.5    0.5    1.0    1.0
        rprobnx_exit_prob     0.5    1.0    0.5    1.0
        0                   False  False  False  False
        1                   False   True  False   True
        2                   False  False  False  False
        3                   False  False   True   True
        4                    True  False  False  False
        ```

        Probabilities can also be set per row, column, or element:

        ```pycon
        >>> entry_prob1 = np.array([1., 0., 1., 0., 1.])
        >>> entry_prob2 = np.array([0., 1., 0., 1., 0.])
        >>> rprobnx = vbt.RPROBNX.run(
        ...     input_shape=(5,),
        ...     entry_prob=[entry_prob1, entry_prob2],
        ...     exit_prob=1.,
        ...     seed=42)

        >>> rprobnx.entries
        rprobnx_entry_prob array_0 array_1
        rprobnx_exit_prob      1.0     1.0
        0                     True   False
        1                    False    True
        2                     True   False
        3                    False    True
        4                     True   False

        >>> rprobnx.exits
        rprobnx_entry_prob array_0 array_1
        rprobnx_exit_prob      1.0     1.0
        0                    False   False
        1                     True   False
        2                    False    True
        3                     True   False
        4                    False    True
        ```
    """


# The subclass above exists only to carry the docstring for the generated class.
RPROBNX.__doc__ = _RPROBNX.__doc__
RPROBNX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `RPROBX`."""

from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb
from vectorbtpro.utils.config import ReadonlyConfig

__all__ = [
    "RPROBX",
]

__pdoc__ = {}

rprobx_config = ReadonlyConfig(
    dict(
        class_name="RPROBX",
        module_name=__name__,
        short_name="rprobx",
        mode="exits",
        param_names=["prob"],
    ),
)
"""Factory config for `RPROBX`."""

rprobx_func_config = ReadonlyConfig(
    dict(
        exit_place_func_nb=rand_by_prob_place_nb,
        exit_settings=dict(
            pass_params=["prob"],
            pass_kwargs=["pick_first"],
        ),
        param_settings=dict(
            prob=flex_elem_param_config,
        ),
        seed=None,
    )
)
"""Exit function config for `RPROBX`."""

RPROBX = SignalFactory(**rprobx_config).with_place_func(**rprobx_func_config)


class _RPROBX(RPROBX):
    """Random exit signal generator based on probabilities.

    Generates `exits` based on `entries` and `vectorbtpro.signals.nb.rand_by_prob_place_nb`.

    See `RPROB` for notes on parameters."""


# The subclass above exists only to carry the docstring for the generated class.
RPROBX.__doc__ = _RPROBX.__doc__
RPROBX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `STCX`."""

from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.stx import stx_config, stx_func_config

__all__ = [
    "STCX",
]

__pdoc__ = {}

# Same factory and place-function configs as `STX`; only the identity
# fields and the mode ("chain" instead of "exits") are overridden.
STCX = SignalFactory(
    **stx_config.merge_with(
        dict(
            class_name="STCX",
            module_name=__name__,
            short_name="stcx",
            mode="chain",
        )
    )
).with_place_func(**stx_func_config)


class _STCX(STCX):
    """Exit signal generator based on stop values.

    Generates chain of `new_entries` and `exits` based on `entries` and
    `vectorbtpro.signals.nb.stop_place_nb`.

    See `STX` for notes on parameters."""


# The subclass above exists only to carry the docstring for the generated class.
STCX.__doc__ = _STCX.__doc__
STCX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Module with `STX`."""

import numpy as np

from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import stop_place_nb
from vectorbtpro.utils.config import ReadonlyConfig

__all__ = [
    "STX",
]

__pdoc__ = {}

# Both configs are module-level and read-only: the `STCX` module imports
# and merges them to build its own class.
stx_config = ReadonlyConfig(
    {
        "class_name": "STX",
        "module_name": __name__,
        "short_name": "stx",
        "mode": "exits",
        "input_names": ["entry_ts", "ts", "follow_ts"],
        "in_output_names": ["stop_ts"],
        "param_names": ["stop", "trailing"],
    }
)
"""Factory config for `STX`."""

stx_func_config = ReadonlyConfig(
    {
        "exit_place_func_nb": stop_place_nb,
        "exit_settings": {
            "pass_inputs": ["entry_ts", "ts", "follow_ts"],
            "pass_in_outputs": ["stop_ts"],
            "pass_params": ["stop", "trailing"],
        },
        "param_settings": {
            "stop": flex_elem_param_config,
            "trailing": flex_elem_param_config,
        },
        "trailing": False,
        "ts": np.nan,
        "follow_ts": np.nan,
        "stop_ts": np.nan,
    }
)
"""Exit function config for `STX`."""

STX = SignalFactory(**stx_config).with_place_func(**stx_func_config)


class _STX(STX):
    """Exit signal generator based on stop values.

    Generates `exits` based on `entries` and `vectorbtpro.signals.nb.stop_place_nb`.

    !!! hint
        All parameters can be either a single value (per frame) or a NumPy array (per row, column,
        or element). To generate multiple combinations, pass them as lists."""


# The subclass above exists only to carry the docstring for the generated class.
STX.__doc__ = _STX.__doc__
STX.fix_docstrings(__pdoc__)

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== """Modules for working with signals.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.signals.accessors import * from vectorbtpro.signals.factory import * from vectorbtpro.signals.generators import * from vectorbtpro.signals.nb import * __exclude_from__all__ = [ "enums", ] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Custom Pandas accessors for signals. Methods can be accessed as follows: * `SignalsSRAccessor` -> `pd.Series.vbt.signals.*` * `SignalsDFAccessor` -> `pd.DataFrame.vbt.signals.*` ```pycon >>> from vectorbtpro import * >>> # vectorbtpro.signals.accessors.SignalsAccessor.pos_rank >>> pd.Series([False, True, True, True, False]).vbt.signals.pos_rank() 0 -1 1 0 2 1 3 2 4 -1 dtype: int64 ``` The accessors extend `vectorbtpro.generic.accessors`. !!! note The underlying Series/DataFrame must already be a signal series and have boolean data type. Grouping is only supported by the methods that accept the `group_by` argument. Accessors do not utilize caching. Run for the examples below: ```pycon >>> mask = pd.DataFrame({ ... 'a': [True, False, False, False, False], ... 'b': [True, False, True, False, True], ... 'c': [True, True, True, False, False] ... }, index=pd.date_range("2020", periods=5)) >>> mask a b c 2020-01-01 True True True 2020-01-02 False False True 2020-01-03 False True True 2020-01-04 False False False 2020-01-05 False True False ``` ## Stats !!! 
hint See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `SignalsAccessor.metrics`. ```pycon >>> mask.vbt.signals.stats(column='a') Start 2020-01-01 00:00:00 End 2020-01-05 00:00:00 Period 5 days 00:00:00 Total 1 Rate [%] 20.0 First Index 2020-01-01 00:00:00 Last Index 2020-01-01 00:00:00 Norm Avg Index [-1, 1] -1.0 Distance: Min NaT Distance: Median NaT Distance: Max NaT Total Partitions 1 Partition Rate [%] 100.0 Partition Length: Min 1 days 00:00:00 Partition Length: Median 1 days 00:00:00 Partition Length: Max 1 days 00:00:00 Partition Distance: Min NaT Partition Distance: Median NaT Partition Distance: Max NaT Name: a, dtype: object ``` We can pass another signal array to compare this array with: ```pycon >>> mask.vbt.signals.stats(column='a', settings=dict(target=mask['b'])) Start 2020-01-01 00:00:00 End 2020-01-05 00:00:00 Period 5 days 00:00:00 Total 1 Rate [%] 20.0 Total Overlapping 1 Overlapping Rate [%] 33.333333 First Index 2020-01-01 00:00:00 Last Index 2020-01-01 00:00:00 Norm Avg Index [-1, 1] -1.0 Distance -> Target: Min 0 days 00:00:00 Distance -> Target: Median 2 days 00:00:00 Distance -> Target: Max 4 days 00:00:00 Total Partitions 1 Partition Rate [%] 100.0 Partition Length: Min 1 days 00:00:00 Partition Length: Median 1 days 00:00:00 Partition Length: Max 1 days 00:00:00 Partition Distance: Min NaT Partition Distance: Median NaT Partition Distance: Max NaT Name: a, dtype: object ``` We can also return duration as a floating number rather than a timedelta: ```pycon >>> mask.vbt.signals.stats(column='a', settings=dict(to_timedelta=False)) Start 2020-01-01 00:00:00 End 2020-01-05 00:00:00 Period 5 Total 1 Rate [%] 20.0 First Index 2020-01-01 00:00:00 Last Index 2020-01-01 00:00:00 Norm Avg Index [-1, 1] -1.0 Distance: Min NaN Distance: Median NaN Distance: Max NaN Total Partitions 1 Partition Rate [%] 100.0 Partition Length: Min 1.0 Partition Length: Median 1.0 Partition Length: Max 1.0 Partition Distance: Min NaN Partition 
Distance: Median NaN Partition Distance: Max NaN Name: a, dtype: object ``` `SignalsAccessor.stats` also supports (re-)grouping: ```pycon >>> mask.vbt.signals.stats(column=0, group_by=[0, 0, 1]) Start 2020-01-01 00:00:00 End 2020-01-05 00:00:00 Period 5 days 00:00:00 Total 4 Rate [%] 40.0 First Index 2020-01-01 00:00:00 Last Index 2020-01-05 00:00:00 Norm Avg Index [-1, 1] -0.25 Distance: Min 2 days 00:00:00 Distance: Median 2 days 00:00:00 Distance: Max 2 days 00:00:00 Total Partitions 4 Partition Rate [%] 100.0 Partition Length: Min 1 days 00:00:00 Partition Length: Median 1 days 00:00:00 Partition Length: Max 1 days 00:00:00 Partition Distance: Min 2 days 00:00:00 Partition Distance: Median 2 days 00:00:00 Partition Distance: Max 2 days 00:00:00 Name: 0, dtype: object ``` ## Plots !!! hint See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `SignalsAccessor.subplots`. This class inherits subplots from `vectorbtpro.generic.accessors.GenericAccessor`. """ from functools import partialmethod import numpy as np import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.accessors import register_vbt_accessor, register_df_vbt_accessor, register_sr_vbt_accessor from vectorbtpro.base import chunking as base_ch, reshaping, indexes from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic import nb as generic_nb from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor from vectorbtpro.generic.ranges import Ranges from vectorbtpro.records.mapped_array import MappedArray from vectorbtpro.registries.ch_registry import ch_reg from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.signals import nb, enums from vectorbtpro.utils import checks from vectorbtpro.utils import chunking as ch from vectorbtpro.utils.colors import adjust_lightness from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig from 
vectorbtpro.utils.decorators import hybrid_method, hybrid_property from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.random_ import set_seed_nb from vectorbtpro.utils.template import RepEval, substitute_templates from vectorbtpro.utils.warnings_ import warn __all__ = [ "SignalsAccessor", "SignalsSRAccessor", "SignalsDFAccessor", ] __pdoc__ = {} @register_vbt_accessor("signals") class SignalsAccessor(GenericAccessor): """Accessor on top of signal series. For both, Series and DataFrames. Accessible via `pd.Series.vbt.signals` and `pd.DataFrame.vbt.signals`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, **kwargs, ) -> None: GenericAccessor.__init__(self, wrapper, obj=obj, **kwargs) checks.assert_dtype(self._obj, np.bool_) @hybrid_property def sr_accessor_cls(cls_or_self) -> tp.Type["SignalsSRAccessor"]: """Accessor class for `pd.Series`.""" return SignalsSRAccessor @hybrid_property def df_accessor_cls(cls_or_self) -> tp.Type["SignalsDFAccessor"]: """Accessor class for `pd.DataFrame`.""" return SignalsDFAccessor # ############# Overriding ############# # @classmethod def empty(cls, *args, fill_value: bool = False, **kwargs) -> tp.SeriesFrame: """`vectorbtpro.base.accessors.BaseAccessor.empty` with `fill_value=False`.""" return GenericAccessor.empty(*args, fill_value=fill_value, dtype=np.bool_, **kwargs) @classmethod def empty_like(cls, *args, fill_value: bool = False, **kwargs) -> tp.SeriesFrame: """`vectorbtpro.base.accessors.BaseAccessor.empty_like` with `fill_value=False`.""" return GenericAccessor.empty_like(*args, fill_value=fill_value, dtype=np.bool_, **kwargs) bshift = partialmethod(GenericAccessor.bshift, fill_value=False) fshift = partialmethod(GenericAccessor.fshift, fill_value=False) ago = partialmethod(GenericAccessor.ago, fill_value=False) realign = partialmethod(GenericAccessor.realign, nan_value=False) # ############# Generation ############# # @classmethod def 
generate( cls, shape: tp.Union[tp.ShapeLike, ArrayWrapper], place_func_nb: tp.PlaceFunc, *args, place_args: tp.ArgsLike = None, only_once: bool = True, wait: int = 1, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.signals.nb.generate_nb`. `shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper` (will be used as `wrapper`). Arguments to `place_func_nb` can be passed either as `*args` or `place_args` (but not both!). Usage: * Generate random signals manually: ```pycon >>> @njit ... def place_func_nb(c): ... i = np.random.choice(len(c.out)) ... c.out[i] = True ... return i >>> vbt.pd_acc.signals.generate( ... (5, 3), ... place_func_nb, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... 
) a b c 2020-01-01 True False False 2020-01-02 False True False 2020-01-03 False False True 2020-01-04 False False False 2020-01-05 False False False ``` """ if isinstance(shape, ArrayWrapper): wrapper = shape shape = wrapper.shape if len(args) > 0 and place_args is not None: raise ValueError("Must provide either *args or place_args, not both") if place_args is None: place_args = args if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} shape_2d = cls.resolve_shape(shape) if len(broadcast_named_args) > 0: broadcast_named_args = reshaping.broadcast(broadcast_named_args, to_shape=shape_2d, **broadcast_kwargs) template_context = merge_dicts( broadcast_named_args, dict(shape=shape, shape_2d=shape_2d, wait=wait), template_context, ) place_args = substitute_templates(place_args, template_context, eval_id="place_args") func = jit_reg.resolve_option(nb.generate_nb, jitted) func = ch_reg.resolve_option(func, chunked) result = func( target_shape=shape_2d, place_func_nb=place_func_nb, place_args=place_args, only_once=only_once, wait=wait, ) if wrapper is None: wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim) if wrap_kwargs is None: wrap_kwargs = resolve_dict(wrap_kwargs) return wrapper.wrap(result, **wrap_kwargs) @classmethod def generate_both( cls, shape: tp.Union[tp.ShapeLike, ArrayWrapper], entry_place_func_nb: tp.PlaceFunc, exit_place_func_nb: tp.PlaceFunc, *args, entry_place_args: tp.ArgsLike = None, exit_place_args: tp.ArgsLike = None, entry_wait: int = 1, exit_wait: int = 1, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]: """See `vectorbtpro.signals.nb.generate_enex_nb`. 
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper` (will be used as `wrapper`). Arguments to `entry_place_func_nb` can be passed either as `*args` or `entry_place_args` while arguments to `exit_place_func_nb` can be passed either as `*args` or `exit_place_args` (but not both!). Usage: * Generate entry and exit signals one after another: ```pycon >>> @njit ... def place_func_nb(c): ... c.out[0] = True ... return 0 >>> en, ex = vbt.pd_acc.signals.generate_both( ... (5, 3), ... entry_place_func_nb=place_func_nb, ... exit_place_func_nb=place_func_nb, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... ) >>> en a b c 2020-01-01 True True True 2020-01-02 False False False 2020-01-03 True True True 2020-01-04 False False False 2020-01-05 True True True >>> ex a b c 2020-01-01 False False False 2020-01-02 True True True 2020-01-03 False False False 2020-01-04 True True True 2020-01-05 False False False ``` * Generate three entries and one exit one after another: ```pycon >>> @njit ... def entry_place_func_nb(c, n): ... c.out[:n] = True ... return n - 1 >>> @njit ... def exit_place_func_nb(c, n): ... c.out[:n] = True ... return n - 1 >>> en, ex = vbt.pd_acc.signals.generate_both( ... (5, 3), ... entry_place_func_nb=entry_place_func_nb, ... entry_place_args=(3,), ... exit_place_func_nb=exit_place_func_nb, ... exit_place_args=(1,), ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... 
) >>> en a b c 2020-01-01 True True True 2020-01-02 True True True 2020-01-03 True True True 2020-01-04 False False False 2020-01-05 True True True >>> ex a b c 2020-01-01 False False False 2020-01-02 False False False 2020-01-03 False False False 2020-01-04 True True True 2020-01-05 False False False ``` """ if isinstance(shape, ArrayWrapper): wrapper = shape shape = wrapper.shape if len(args) > 0 and entry_place_args is not None: raise ValueError("Must provide either *args or entry_place_args, not both") if len(args) > 0 and exit_place_args is not None: raise ValueError("Must provide either *args or exit_place_args, not both") if entry_place_args is None: entry_place_args = args if exit_place_args is None: exit_place_args = args if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} shape_2d = cls.resolve_shape(shape) if len(broadcast_named_args) > 0: broadcast_named_args = reshaping.broadcast( broadcast_named_args, to_shape=shape_2d, **broadcast_kwargs, ) template_context = merge_dicts( broadcast_named_args, dict( shape=shape, shape_2d=shape_2d, entry_wait=entry_wait, exit_wait=exit_wait, ), template_context, ) entry_place_args = substitute_templates(entry_place_args, template_context, eval_id="entry_place_args") exit_place_args = substitute_templates(exit_place_args, template_context, eval_id="exit_place_args") func = jit_reg.resolve_option(nb.generate_enex_nb, jitted) func = ch_reg.resolve_option(func, chunked) result1, result2 = func( target_shape=shape_2d, entry_place_func_nb=entry_place_func_nb, entry_place_args=entry_place_args, exit_place_func_nb=exit_place_func_nb, exit_place_args=exit_place_args, entry_wait=entry_wait, exit_wait=exit_wait, ) if wrapper is None: wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim) if wrap_kwargs is None: wrap_kwargs = resolve_dict(wrap_kwargs) return wrapper.wrap(result1, **wrap_kwargs), wrapper.wrap(result2, 
**wrap_kwargs) def generate_exits( self, exit_place_func_nb: tp.PlaceFunc, *args, exit_place_args: tp.ArgsLike = None, wait: int = 1, until_next: bool = True, skip_until_exit: bool = False, broadcast_named_args: tp.KwargsLike = None, broadcast_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.SeriesFrame: """See `vectorbtpro.signals.nb.generate_ex_nb`. Usage: * Generate an exit just before the next entry: ```pycon >>> @njit ... def exit_place_func_nb(c): ... c.out[-1] = True ... return len(c.out) - 1 >>> mask.vbt.signals.generate_exits(exit_place_func_nb) a b c 2020-01-01 False False False 2020-01-02 False True False 2020-01-03 False False False 2020-01-04 False True False 2020-01-05 True False True ``` """ if len(args) > 0 and exit_place_args is not None: raise ValueError("Must provide either *args or exit_place_args, not both") if exit_place_args is None: exit_place_args = args if broadcast_named_args is None: broadcast_named_args = {} if broadcast_kwargs is None: broadcast_kwargs = {} if template_context is None: template_context = {} obj = self.obj if len(broadcast_named_args) > 0: broadcast_named_args = {"obj": obj, **broadcast_named_args} broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) broadcast_named_args, wrapper = reshaping.broadcast( broadcast_named_args, return_wrapper=True, **broadcast_kwargs, ) obj = broadcast_named_args["obj"] else: wrapper = self.wrapper obj = reshaping.to_2d_array(obj) template_context = merge_dicts( broadcast_named_args, dict(wait=wait, until_next=until_next, skip_until_exit=skip_until_exit), template_context, ) exit_place_args = substitute_templates(exit_place_args, template_context, eval_id="exit_place_args") func = jit_reg.resolve_option(nb.generate_ex_nb, jitted) func = ch_reg.resolve_option(func, chunked) exits = func( entries=obj, 
exit_place_func_nb=exit_place_func_nb, exit_place_args=exit_place_args, wait=wait, until_next=until_next, skip_until_exit=skip_until_exit, ) return wrapper.wrap(exits, group_by=False, **resolve_dict(wrap_kwargs)) # ############# Cleaning ############# # @hybrid_method def clean( cls_or_self, *objs, force_first: bool = True, keep_conflicts: bool = False, reverse_order: bool = False, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.MaybeTuple[tp.SeriesFrame]: """Clean signals. If one array is passed, see `SignalsAccessor.first`. If two arrays passed, entries and exits, see `vectorbtpro.signals.nb.clean_enex_nb`.""" if broadcast_kwargs is None: broadcast_kwargs = {} if wrap_kwargs is None: wrap_kwargs = {} if not isinstance(cls_or_self, type): objs = (cls_or_self.obj, *objs) if len(objs) == 1: obj = objs[0] if not isinstance(obj, (pd.Series, pd.DataFrame)): obj = ArrayWrapper.from_obj(obj).wrap(obj) return obj.vbt.signals.first(wrap_kwargs=wrap_kwargs, jitted=jitted, chunked=chunked) if len(objs) == 2: broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs) broadcasted_args, wrapper = reshaping.broadcast( dict(entries=objs[0], exits=objs[1]), return_wrapper=True, **broadcast_kwargs, ) func = jit_reg.resolve_option(nb.clean_enex_nb, jitted) func = ch_reg.resolve_option(func, chunked) entries_out, exits_out = func( entries=broadcasted_args["entries"], exits=broadcasted_args["exits"], force_first=force_first, keep_conflicts=keep_conflicts, reverse_order=reverse_order, ) return ( wrapper.wrap(entries_out, group_by=False, **wrap_kwargs), wrapper.wrap(exits_out, group_by=False, **wrap_kwargs), ) raise ValueError("This method accepts either one or two arrays") # ############# Random signals ############# # @classmethod def generate_random( cls, shape: tp.Union[tp.ShapeLike, ArrayWrapper], n: tp.Optional[tp.ArrayLike] = None, prob: 
tp.Optional[tp.ArrayLike] = None, pick_first: bool = False, seed: tp.Optional[int] = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs, ) -> tp.SeriesFrame: """Generate signals randomly. `shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper` (will be used as `wrapper`). If `n` is set, uses `vectorbtpro.signals.nb.rand_place_nb`. If `prob` is set, uses `vectorbtpro.signals.nb.rand_by_prob_place_nb`. For arguments, see `SignalsAccessor.generate`. `n` must be either a scalar or an array that will broadcast to the number of columns. `prob` must be either a single number or an array that will broadcast to match `shape`. Specify `seed` to make output deterministic. Usage: * For each column, generate a variable number of signals: ```pycon >>> vbt.pd_acc.signals.generate_random( ... (5, 3), ... n=[0, 1, 2], ... seed=42, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... ) a b c 2020-01-01 False False False 2020-01-02 False False False 2020-01-03 False False True 2020-01-04 False True False 2020-01-05 False False True ``` * For each column and time step, pick a signal with 50% probability: ```pycon >>> vbt.pd_acc.signals.generate_random( ... (5, 3), ... prob=0.5, ... seed=42, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... 
) a b c 2020-01-01 True True True 2020-01-02 False True False 2020-01-03 False False False 2020-01-04 False False True 2020-01-05 True False True ``` """ if isinstance(shape, ArrayWrapper): if "wrapper" in kwargs: raise ValueError("Must provide wrapper either via shape or wrapper, not both") kwargs["wrapper"] = shape shape = kwargs["wrapper"].shape shape_2d = cls.resolve_shape(shape) if n is not None and prob is not None: raise ValueError("Must provide either n or prob, not both") if seed is not None: set_seed_nb(seed) if n is not None: n = reshaping.broadcast_array_to(n, shape_2d[1]) chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(), ), ), ) return cls.generate( shape, jit_reg.resolve_option(nb.rand_place_nb, jitted), n, jitted=jitted, chunked=chunked, **kwargs, ) if prob is not None: prob = reshaping.to_2d_array(reshaping.broadcast_array_to(prob, shape)) chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), None, None, ), ), ) return cls.generate( shape, jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted), prob, pick_first, jitted=jitted, chunked=chunked, **kwargs, ) raise ValueError("Must provide at least n or prob") @classmethod def generate_random_both( cls, shape: tp.Union[tp.ShapeLike, ArrayWrapper], n: tp.Optional[tp.ArrayLike] = None, entry_prob: tp.Optional[tp.ArrayLike] = None, exit_prob: tp.Optional[tp.ArrayLike] = None, seed: tp.Optional[int] = None, entry_wait: int = 1, exit_wait: int = 1, entry_pick_first: bool = True, exit_pick_first: bool = True, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrapper: tp.Optional[ArrayWrapper] = None, wrap_kwargs: tp.KwargsLike = None, ) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]: """Generate chain of entry and exit signals randomly. 
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper` (will be used as `wrapper`). If `n` is set, uses `vectorbtpro.signals.nb.generate_rand_enex_nb`. If `entry_prob` and `exit_prob` are set, uses `SignalsAccessor.generate_both` with `vectorbtpro.signals.nb.rand_by_prob_place_nb`. Usage: * For each column, generate two entries and exits randomly: ```pycon >>> en, ex = vbt.pd_acc.signals.generate_random_both( ... (5, 3), ... n=2, ... seed=42, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... ) >>> en a b c 2020-01-01 False False True 2020-01-02 True True False 2020-01-03 False False False 2020-01-04 True True True 2020-01-05 False False False >>> ex a b c 2020-01-01 False False False 2020-01-02 False False True 2020-01-03 True True False 2020-01-04 False False False 2020-01-05 True True True ``` * For each column and time step, pick entry with 50% probability and exit right after: ```pycon >>> en, ex = vbt.pd_acc.signals.generate_random_both( ... (5, 3), ... entry_prob=0.5, ... exit_prob=1., ... seed=42, ... wrap_kwargs=dict( ... index=mask.index, ... columns=mask.columns ... ) ... 
) >>> en a b c 2020-01-01 True True True 2020-01-02 False False False 2020-01-03 False False False 2020-01-04 False False True 2020-01-05 True False False >>> ex a b c 2020-01-01 False False False 2020-01-02 True True True 2020-01-03 False False False 2020-01-04 False False False 2020-01-05 False False True ``` """ if isinstance(shape, ArrayWrapper): wrapper = shape shape = wrapper.shape shape_2d = cls.resolve_shape(shape) if n is not None and (entry_prob is not None or exit_prob is not None): raise ValueError("Must provide either n or any of the entry_prob and exit_prob, not both") if seed is not None: set_seed_nb(seed) if n is not None: n = reshaping.broadcast_array_to(n, shape_2d[1]) func = jit_reg.resolve_option(nb.generate_rand_enex_nb, jitted) func = ch_reg.resolve_option(func, chunked) entries, exits = func(shape_2d, n, entry_wait, exit_wait) if wrapper is None: wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim) if wrap_kwargs is None: wrap_kwargs = resolve_dict(wrap_kwargs) return wrapper.wrap(entries, **wrap_kwargs), wrapper.wrap(exits, **wrap_kwargs) elif entry_prob is not None and exit_prob is not None: entry_prob = reshaping.to_2d_array(reshaping.broadcast_array_to(entry_prob, shape)) exit_prob = reshaping.to_2d_array(reshaping.broadcast_array_to(exit_prob, shape)) chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( entry_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), None, ), exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), None, ), ), ) return cls.generate_both( shape, entry_place_func_nb=jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted), entry_place_args=(entry_prob, entry_pick_first), exit_place_func_nb=jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted), exit_place_args=(exit_prob, exit_pick_first), entry_wait=entry_wait, exit_wait=exit_wait, jitted=jitted, chunked=chunked, wrapper=wrapper, wrap_kwargs=wrap_kwargs, ) raise ValueError("Must provide at least n, or entry_prob and exit_prob") 
def generate_random_exits( self, prob: tp.Optional[tp.ArrayLike] = None, seed: tp.Optional[int] = None, wait: int = 1, until_next: bool = True, skip_until_exit: bool = False, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.SeriesFrame: """Generate exit signals randomly. If `prob` is None, uses `vectorbtpro.signals.nb.rand_place_nb`. Otherwise, uses `vectorbtpro.signals.nb.rand_by_prob_place_nb`. Uses `SignalsAccessor.generate_exits`. Specify `seed` to make output deterministic. Usage: * After each entry in `mask`, generate exactly one exit: ```pycon >>> mask.vbt.signals.generate_random_exits(seed=42) a b c 2020-01-01 False False False 2020-01-02 False True False 2020-01-03 False False False 2020-01-04 True True False 2020-01-05 False False True ``` * After each entry in `mask` and at each time step, generate exit with 50% probability: ```pycon >>> mask.vbt.signals.generate_random_exits(prob=0.5, seed=42) a b c 2020-01-01 False False False 2020-01-02 True False False 2020-01-03 False False False 2020-01-04 False False False 2020-01-05 False False True ``` """ if seed is not None: set_seed_nb(seed) if prob is not None: broadcast_kwargs = merge_dicts( dict(keep_flex=dict(obj=False, prob=True)), broadcast_kwargs, ) broadcasted_args = reshaping.broadcast( dict(obj=self.obj, prob=prob), **broadcast_kwargs, ) obj = broadcasted_args["obj"] prob = broadcasted_args["prob"] chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), None, ) ), ) return obj.vbt.signals.generate_exits( jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted), prob, True, wait=wait, until_next=until_next, skip_until_exit=skip_until_exit, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) n = reshaping.broadcast_array_to(1, self.wrapper.shape_2d[1]) chunked = ch.specialize_chunked_option( chunked, 
arg_take_spec=dict( exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(), ) ), ) return self.generate_exits( jit_reg.resolve_option(nb.rand_place_nb, jitted), n, wait=wait, until_next=until_next, skip_until_exit=skip_until_exit, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) # ############# Stop signals ############# # def generate_stop_exits( self, entry_ts: tp.ArrayLike, ts: tp.ArrayLike = np.nan, follow_ts: tp.ArrayLike = np.nan, stop: tp.ArrayLike = np.nan, trailing: tp.ArrayLike = False, out_dict: tp.Optional[tp.Dict[str, tp.ArrayLike]] = None, entry_wait: int = 1, exit_wait: int = 1, until_next: bool = True, skip_until_exit: bool = False, chain: bool = False, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeTuple[tp.SeriesFrame]: """Generate exits based on when `ts` hits the stop. For arguments, see `vectorbtpro.signals.nb.stop_place_nb`. If `chain` is True, uses `SignalsAccessor.generate_both`. Otherwise, uses `SignalsAccessor.generate_exits`. Use `out_dict` as a dict to pass `stop_ts` array. You can also set `out_dict` to {} to produce this array automatically and still have access to it. All array-like arguments including stops and `out_dict` will broadcast using `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`. !!! hint Default arguments will generate an exit signal strictly between two entry signals. If both entry signals are too close to each other, no exit will be generated. To ignore all entries that come between an entry and its exit, set `until_next` to False and `skip_until_exit` to True. To remove all entries that come between an entry and its exit, set `chain` to True. This will return two arrays: new entries and exits. 
Usage: * Regular stop loss: ```pycon >>> ts = pd.Series([1, 2, 3, 2, 1]) >>> mask.vbt.signals.generate_stop_exits(ts, stop=-0.1) a b c 2020-01-01 False False False 2020-01-02 False False False 2020-01-03 False False False 2020-01-04 False True True 2020-01-05 False False False ``` * Trailing stop loss: ```pycon >>> mask.vbt.signals.generate_stop_exits(ts, stop=-0.1, trailing=True) a b c 2020-01-01 False False False 2020-01-02 False False False 2020-01-03 False False False 2020-01-04 True True True 2020-01-05 False False False ``` * Testing multiple take profit stops: ```pycon >>> mask.vbt.signals.generate_stop_exits(ts, stop=vbt.Param([1.0, 1.5])) stop 1.0 1.5 a b c a b c 2020-01-01 False False False False False False 2020-01-02 True True False False False False 2020-01-03 False False False True False False 2020-01-04 False False False False False False 2020-01-05 False False False False False False ``` """ if wrap_kwargs is None: wrap_kwargs = {} entries = self.obj if out_dict is None: out_dict_passed = False out_dict = {} else: out_dict_passed = True stop_ts = out_dict.get("stop_ts", np.nan if out_dict_passed else None) broadcastable_args = dict( entries=entries, entry_ts=entry_ts, ts=ts, follow_ts=follow_ts, stop=stop, trailing=trailing, stop_ts=stop_ts, ) broadcast_kwargs = merge_dicts( dict( keep_flex=dict(entries=False, stop_ts=False, _def=True), require_kwargs=dict(requirements="W"), ), broadcast_kwargs, ) broadcasted_args = reshaping.broadcast(broadcastable_args, **broadcast_kwargs) entries = broadcasted_args["entries"] stop_ts = broadcasted_args["stop_ts"] if stop_ts is None: stop_ts = np.empty_like(entries, dtype=float_) stop_ts = reshaping.to_2d_array(stop_ts) entries_arr = reshaping.to_2d_array(entries) wrapper = ArrayWrapper.from_obj(entries) if chain: if checks.is_series(entries): cls = self.sr_accessor_cls else: cls = self.df_accessor_cls chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( entry_place_args=ch.ArgsTaker( 
ch.ArraySlicer(axis=1), ), exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), ), ), ) out_dict["stop_ts"] = wrapper.wrap(stop_ts, group_by=False, **wrap_kwargs) return cls.generate_both( entries.shape, entry_place_func_nb=jit_reg.resolve_option(nb.first_place_nb, jitted), entry_place_args=(entries_arr,), exit_place_func_nb=jit_reg.resolve_option(nb.stop_place_nb, jitted), exit_place_args=( broadcasted_args["entry_ts"], broadcasted_args["ts"], broadcasted_args["follow_ts"], stop_ts, broadcasted_args["stop"], broadcasted_args["trailing"], ), entry_wait=entry_wait, exit_wait=exit_wait, wrapper=wrapper, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) else: chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), ) ), ) if skip_until_exit and until_next: warn("skip_until_exit=True has only effect when until_next=False") out_dict["stop_ts"] = wrapper.wrap(stop_ts, group_by=False, **wrap_kwargs) return entries.vbt.signals.generate_exits( jit_reg.resolve_option(nb.stop_place_nb, jitted), broadcasted_args["entry_ts"], broadcasted_args["ts"], broadcasted_args["follow_ts"], stop_ts, broadcasted_args["stop"], broadcasted_args["trailing"], wait=exit_wait, until_next=until_next, skip_until_exit=skip_until_exit, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) def generate_ohlc_stop_exits( self, entry_price: tp.ArrayLike, open: tp.ArrayLike = np.nan, high: tp.ArrayLike = np.nan, low: tp.ArrayLike = np.nan, close: tp.ArrayLike = np.nan, sl_stop: tp.ArrayLike = np.nan, tsl_th: tp.ArrayLike = np.nan, tsl_stop: tp.ArrayLike = np.nan, 
tp_stop: tp.ArrayLike = np.nan, reverse: tp.ArrayLike = False, is_entry_open: bool = False, out_dict: tp.Optional[tp.Dict[str, tp.ArrayLike]] = None, entry_wait: int = 1, exit_wait: int = 1, until_next: bool = True, skip_until_exit: bool = False, chain: bool = False, broadcast_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, wrap_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeTuple[tp.SeriesFrame]: """Generate exits based on when the price hits (trailing) stop loss or take profit. Use `out_dict` as a dict to pass `stop_price` and `stop_type` arrays. You can also set `out_dict` to {} to produce these arrays automatically and still have access to them. For arguments, see `vectorbtpro.signals.nb.ohlc_stop_place_nb`. If `chain` is True, uses `SignalsAccessor.generate_both`. Otherwise, uses `SignalsAccessor.generate_exits`. All array-like arguments including stops and `out_dict` will broadcast using `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`. For arguments, see `vectorbtpro.signals.nb.ohlc_stop_place_nb`. !!! hint Default arguments will generate an exit signal strictly between two entry signals. If both entry signals are too close to each other, no exit will be generated. To ignore all entries that come between an entry and its exit, set `until_next` to False and `skip_until_exit` to True. To remove all entries that come between an entry and its exit, set `chain` to True. This will return two arrays: new entries and exits. Usage: * Generate exits for TSL and TP of 10%: ```pycon >>> price = pd.DataFrame({ ... 'open': [10, 11, 12, 11, 10], ... 'high': [11, 12, 13, 12, 11], ... 'low': [9, 10, 11, 10, 9], ... 'close': [10, 11, 12, 11, 10] ... }) >>> out_dict = {} >>> exits = mask.vbt.signals.generate_ohlc_stop_exits( ... price["open"], ... price['open'], ... price['high'], ... price['low'], ... price['close'], ... tsl_stop=0.1, ... tp_stop=0.1, ... is_entry_open=True, ... out_dict=out_dict, ... 
) >>> exits a b c 2020-01-01 False False False 2020-01-02 True True False 2020-01-03 False False False 2020-01-04 False True True 2020-01-05 False False False >>> out_dict['stop_price'] a b c 2020-01-01 NaN NaN NaN 2020-01-02 11.0 11.0 NaN 2020-01-03 NaN NaN NaN 2020-01-04 NaN 10.8 10.8 2020-01-05 NaN NaN NaN >>> out_dict['stop_type'].vbt(mapping=vbt.sig_enums.StopType).apply_mapping() a b c 2020-01-01 None None None 2020-01-02 TP TP None 2020-01-03 None None None 2020-01-04 None TSL TSL 2020-01-05 None None None ``` Notice how the first two entry signals in the third column have no exit signal - there is no room between them for an exit signal. * To find an exit for the first entry and ignore all entries that are in-between them, we can pass `until_next=False` and `skip_until_exit=True`: ```pycon >>> out_dict = {} >>> exits = mask.vbt.signals.generate_ohlc_stop_exits( ... price['open'], ... price['open'], ... price['high'], ... price['low'], ... price['close'], ... tsl_stop=0.1, ... tp_stop=0.1, ... is_entry_open=True, ... out_dict=out_dict, ... until_next=False, ... skip_until_exit=True ... ) >>> exits a b c 2020-01-01 False False False 2020-01-02 True True True 2020-01-03 False False False 2020-01-04 False True True 2020-01-05 False False False >>> out_dict['stop_price'] a b c 2020-01-01 NaN NaN NaN 2020-01-02 11.0 11.0 11.0 2020-01-03 NaN NaN NaN 2020-01-04 NaN 10.8 10.8 2020-01-05 NaN NaN NaN >>> out_dict['stop_type'].vbt(mapping=vbt.sig_enums.StopType).apply_mapping() a b c 2020-01-01 None None None 2020-01-02 TP TP TP 2020-01-03 None None None 2020-01-04 None TSL TSL 2020-01-05 None None None ``` Now, the first signal in the third column gets executed regardless of the entries that come next, which is very similar to the logic that is implemented in `vectorbtpro.portfolio.base.Portfolio.from_signals`. * To automatically remove all ignored entry signals, pass `chain=True`. 
This will return a new entries array: ```pycon >>> out_dict = {} >>> new_entries, exits = mask.vbt.signals.generate_ohlc_stop_exits( ... price['open'], ... price['open'], ... price['high'], ... price['low'], ... price['close'], ... tsl_stop=0.1, ... tp_stop=0.1, ... is_entry_open=True, ... out_dict=out_dict, ... chain=True ... ) >>> new_entries a b c 2020-01-01 True True True 2020-01-02 False False False << removed entry in the third column 2020-01-03 False True True 2020-01-04 False False False 2020-01-05 False True False >>> exits a b c 2020-01-01 False False False 2020-01-02 True True True 2020-01-03 False False False 2020-01-04 False True True 2020-01-05 False False False ``` !!! warning The last two examples above make entries dependent upon exits - this makes only sense if you have no other exit arrays to combine this stop exit array with. * Test multiple parameter combinations: ```pycon >>> exits = mask.vbt.signals.generate_ohlc_stop_exits( ... price['open'], ... price['open'], ... price['high'], ... price['low'], ... price['close'], ... sl_stop=vbt.Param([False, 0.1]), ... tsl_stop=vbt.Param([False, 0.1]), ... is_entry_open=True ... 
) >>> exits sl_stop False 0.1 \\ tsl_stop False 0.1 False a b c a b c a b c 2020-01-01 False False False False False False False False False 2020-01-02 False False False False False False False False False 2020-01-03 False False False False False False False False False 2020-01-04 False False False True True True False True True 2020-01-05 False False False False False False True False False sl_stop tsl_stop 0.1 a b c 2020-01-01 False False False 2020-01-02 False False False 2020-01-03 False False False 2020-01-04 True True True 2020-01-05 False False False ``` """ if wrap_kwargs is None: wrap_kwargs = {} entries = self.obj if out_dict is None: out_dict_passed = False out_dict = {} else: out_dict_passed = True stop_price = out_dict.get("stop_price", np.nan if out_dict_passed else None) stop_type = out_dict.get("stop_type", -1 if out_dict_passed else None) broadcastable_args = dict( entries=entries, entry_price=entry_price, open=open, high=high, low=low, close=close, sl_stop=sl_stop, tsl_th=tsl_th, tsl_stop=tsl_stop, tp_stop=tp_stop, reverse=reverse, stop_price=stop_price, stop_type=stop_type, ) broadcast_kwargs = merge_dicts( dict( keep_flex=dict(entries=False, stop_price=False, stop_type=False, _def=True), require_kwargs=dict(requirements="W"), ), broadcast_kwargs, ) broadcasted_args = reshaping.broadcast(broadcastable_args, **broadcast_kwargs) entries = broadcasted_args["entries"] stop_price = broadcasted_args["stop_price"] if stop_price is None: stop_price = np.empty_like(entries, dtype=float_) stop_price = reshaping.to_2d_array(stop_price) stop_type = broadcasted_args["stop_type"] if stop_type is None: stop_type = np.empty_like(entries, dtype=int_) stop_type = reshaping.to_2d_array(stop_type) entries_arr = reshaping.to_2d_array(entries) wrapper = ArrayWrapper.from_obj(entries) if chain: if checks.is_series(entries): cls = self.sr_accessor_cls else: cls = self.df_accessor_cls chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( 
entry_place_args=ch.ArgsTaker( ch.ArraySlicer(axis=1), ), exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), ch.ArraySlicer(axis=1), ch.ArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), None, ), ), ) new_entries, exits = cls.generate_both( entries.shape, entry_place_func_nb=jit_reg.resolve_option(nb.first_place_nb, jitted), entry_place_args=(entries_arr,), exit_place_func_nb=jit_reg.resolve_option(nb.ohlc_stop_place_nb, jitted), exit_place_args=( broadcasted_args["entry_price"], broadcasted_args["open"], broadcasted_args["high"], broadcasted_args["low"], broadcasted_args["close"], stop_price, stop_type, broadcasted_args["sl_stop"], broadcasted_args["tsl_th"], broadcasted_args["tsl_stop"], broadcasted_args["tp_stop"], broadcasted_args["reverse"], is_entry_open, ), entry_wait=entry_wait, exit_wait=exit_wait, wrapper=wrapper, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs, ) out_dict["stop_price"] = wrapper.wrap(stop_price, group_by=False, **wrap_kwargs) out_dict["stop_type"] = wrapper.wrap(stop_type, group_by=False, **wrap_kwargs) return new_entries, exits else: if skip_until_exit and until_next: warn("skip_until_exit=True has only effect when until_next=False") chunked = ch.specialize_chunked_option( chunked, arg_take_spec=dict( exit_place_args=ch.ArgsTaker( base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), ch.ArraySlicer(axis=1), ch.ArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), base_ch.FlexArraySlicer(axis=1), 
def rank(
    self,
    rank_func_nb: tp.RankFunc,
    *args,
    rank_args: tp.ArgsLike = None,
    reset_by: tp.Optional[tp.ArrayLike] = None,
    after_false: bool = False,
    after_reset: bool = False,
    reset_wait: int = 1,
    as_mapped: bool = False,
    broadcast_named_args: tp.KwargsLike = None,
    broadcast_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
    """Rank signals with `rank_func_nb`. See `vectorbtpro.signals.nb.rank_nb`.

    Positional arguments for `rank_func_nb` are given either through `*args` or through
    `rank_args`; providing both raises an error. Templates in these arguments are substituted
    using `template_context`, which additionally exposes `obj`, `reset_by`, `after_false`,
    `after_reset`, and `reset_wait`.

    The mask is broadcast together with `reset_by` and `broadcast_named_args` using
    `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`. If there is nothing to
    broadcast against, the accessor's own wrapper is used.

    If `as_mapped` is True, ranks of -1 are replaced by NaN, dropped, and the result is returned
    as an instance of `vectorbtpro.records.mapped_array.MappedArray`; `**kwargs` are forwarded
    to `to_mapped` in this case only."""
    if args and rank_args is not None:
        raise ValueError("Must provide either *args or rank_args, not both")
    rank_args = args if rank_args is None else rank_args
    extra_named = {} if broadcast_named_args is None else broadcast_named_args
    broadcast_kwargs = {} if broadcast_kwargs is None else broadcast_kwargs
    template_context = {} if template_context is None else template_context
    wrap_kwargs = {} if wrap_kwargs is None else wrap_kwargs

    # The mask always takes part; `reset_by` only if provided. User-supplied entries come last
    if reset_by is None:
        own_named = {"obj": self.obj}
    else:
        own_named = {"obj": self.obj, "reset_by": reset_by}
    named_args = {**own_named, **extra_named}

    if len(named_args) == 1:
        # Only the mask itself: no broadcasting required
        wrapper = self.wrapper
    else:
        broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
        named_args, wrapper = reshaping.broadcast(
            named_args,
            return_wrapper=True,
            **broadcast_kwargs,
        )

    mask_2d = reshaping.to_2d_array(named_args["obj"])
    reset_by_2d = None if reset_by is None else reshaping.to_2d_array(named_args["reset_by"])

    # The same settings go into the template context and into the Numba call
    shared_params = dict(
        reset_by=reset_by_2d,
        after_false=after_false,
        after_reset=after_reset,
        reset_wait=reset_wait,
    )
    template_context = merge_dicts(dict(obj=mask_2d, **shared_params), template_context)
    rank_args = substitute_templates(rank_args, template_context, eval_id="rank_args")

    rank_func = ch_reg.resolve_option(jit_reg.resolve_option(nb.rank_nb, jitted), chunked)
    ranks = rank_func(
        mask=mask_2d,
        rank_func_nb=rank_func_nb,
        rank_args=rank_args,
        **shared_params,
    )
    wrapped = wrapper.wrap(ranks, group_by=False, **wrap_kwargs)
    if not as_mapped:
        return wrapped
    return wrapped.replace(-1, np.nan).vbt.to_mapped(dropna=True, dtype=int_, **kwargs)
def pos_rank_after(
    self,
    reset_by: tp.ArrayLike,
    after_reset: bool = True,
    allow_gaps: bool = True,
    **kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
    """Get signal position ranks after each signal in `reset_by`.

    Delegates to `SignalsAccessor.pos_rank`, forwarding `reset_by`, `after_reset`,
    `allow_gaps`, and any additional keyword arguments.

    !!! note
        `allow_gaps` is enabled by default."""
    rank_options = dict(
        reset_by=reset_by,
        after_reset=after_reset,
        allow_gaps=allow_gaps,
    )
    return self.pos_rank(**rank_options, **kwargs)
def first(
    self,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Select signals that satisfy the condition `pos_rank == 0`.

    Uses `SignalsAccessor.pos_rank`, to which `**kwargs` are forwarded."""
    ranks = self.pos_rank(**kwargs).values
    is_first = ranks == 0
    wrap_options = resolve_dict(wrap_kwargs)
    return self.wrapper.wrap(is_first, group_by=False, **wrap_options)
def from_nth(
    self,
    n: int,
    wrap_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.SeriesFrame:
    """Select signals that satisfy the condition `pos_rank >= n`.

    Uses `SignalsAccessor.pos_rank`, to which `**kwargs` are forwarded."""
    ranks = self.pos_rank(**kwargs).values
    at_or_after_n = ranks >= n
    wrap_options = resolve_dict(wrap_kwargs)
    return self.wrapper.wrap(at_or_after_n, group_by=False, **wrap_options)
def distance_from_last(
    self,
    nth: int = 1,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
    """See `vectorbtpro.signals.nb.distance_from_last_nb`.

    The result has the same layout as the mask and is wrapped with this accessor's wrapper.

    Usage:
        * Get the distance to the last signal:

        ```pycon
        >>> mask.vbt.signals.distance_from_last()
                    a  b  c
        2020-01-01 -1 -1 -1
        2020-01-02  1  1  1
        2020-01-03  2  2  1
        2020-01-04  3  1  1
        2020-01-05  4  2  2
        ```

        * Get the distance to the second last signal:

        ```pycon
        >>> mask.vbt.signals.distance_from_last(nth=2)
                    a  b  c
        2020-01-01 -1 -1 -1
        2020-01-02 -1 -1  1
        2020-01-03 -1  2  1
        2020-01-04 -1  3  2
        2020-01-05 -1  2  3
        ```
    """
    nb_func = ch_reg.resolve_option(
        jit_reg.resolve_option(nb.distance_from_last_nb, jitted),
        chunked,
    )
    mask_2d = self.to_2d_array()
    distances = nb_func(mask_2d, nth=nth)
    wrap_options = resolve_dict(wrap_kwargs)
    return self.wrapper.wrap(distances, group_by=False, **wrap_options)
def between_ranges(
    self,
    target: tp.Optional[tp.ArrayLike] = None,
    relation: tp.Union[int, str] = "onemany",
    incl_open: bool = False,
    broadcast_kwargs: tp.KwargsLike = None,
    group_by: tp.GroupByLike = None,
    attach_target: bool = False,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    **kwargs,
) -> Ranges:
    """Wrap the result of `vectorbtpro.signals.nb.between_ranges_nb`
    with `vectorbtpro.generic.ranges.Ranges`.

    If `target` specified, see `vectorbtpro.signals.nb.between_two_ranges_nb`.
    Both will broadcast using `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`.

    Args:
        target (array_like): Second mask. If None, ranges are built between the signals
            of this mask alone.
        relation (int or str): Relation between the two masks; a string is mapped to
            a field of `vectorbtpro.signals.enums.SignalRelation`.

            Passed to the Numba function only if `target` is specified.
        incl_open (bool): Passed as `incl_open` to the Numba function.
        broadcast_kwargs (dict): Keyword arguments passed to
            `vectorbtpro.base.reshaping.broadcast`. Used only if `target` is specified.
        group_by (any): Grouping applied to the returned records via `regroup`.
        attach_target (bool): Whether to attach the broadcasted `target` (instead of this mask)
            as `close` to the returned records. Used only if `target` is specified.
        jitted (any): Option for resolving the jitted Numba function.
        chunked (any): Option for resolving the chunked Numba function.
        **kwargs: Keyword arguments passed to `vectorbtpro.generic.ranges.Ranges.from_records`.

            An explicitly passed `close` takes precedence over the attached mask.

    Usage:
        * One array:

        ```pycon
        >>> mask_sr = pd.Series([True, False, False, True, False, True, True])
        >>> ranges = mask_sr.vbt.signals.between_ranges()
        >>> ranges

        >>> ranges.readable
           Range Id  Column  Start Index  End Index  Status
        0         0       0            0          3  Closed
        1         1       0            3          5  Closed
        2         2       0            5          6  Closed

        >>> ranges.duration.values
        array([3, 2, 1])
        ```

        * Two arrays, traversing the signals of the first array:

        ```pycon
        >>> mask_sr1 = pd.Series([True, True, True, False, False])
        >>> mask_sr2 = pd.Series([False, False, True, False, True])
        >>> ranges = mask_sr1.vbt.signals.between_ranges(target=mask_sr2)
        >>> ranges

        >>> ranges.readable
           Range Id  Column  Start Index  End Index  Status
        0         0       0            2          2  Closed
        1         1       0            2          4  Closed

        >>> ranges.duration.values
        array([0, 2])
        ```

        * Two arrays, traversing the signals of the second array:

        ```pycon
        >>> ranges = mask_sr1.vbt.signals.between_ranges(target=mask_sr2, relation="manyone")
        >>> ranges

        >>> ranges.readable
           Range Id  Column  Start Index  End Index  Status
        0         0       0            0          2  Closed
        1         1       0            1          2  Closed
        2         2       0            2          2  Closed

        >>> ranges.duration.values
        array([2, 1, 0])
        ```
    """
    if broadcast_kwargs is None:
        broadcast_kwargs = {}
    if isinstance(relation, str):
        relation = map_enum_fields(relation, enums.SignalRelation)
    if target is None:
        # Single mask: ranges between consecutive signals of this object
        func = jit_reg.resolve_option(nb.between_ranges_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        range_records = func(self.to_2d_array(), incl_open=incl_open)
        wrapper = self.wrapper
        to_attach = self.obj
    else:
        # Two masks: bring both to a common two-dimensional shape first
        broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
        broadcasted_args, wrapper = reshaping.broadcast(
            dict(obj=self.obj, target=target),
            return_wrapper=True,
            **broadcast_kwargs,
        )
        func = jit_reg.resolve_option(nb.between_two_ranges_nb, jitted)
        func = ch_reg.resolve_option(func, chunked)
        range_records = func(
            broadcasted_args["obj"],
            broadcasted_args["target"],
            relation=relation,
            incl_open=incl_open,
        )
        to_attach = broadcasted_args["target"] if attach_target else broadcasted_args["obj"]
    # The mask is attached as `close` by default; user-supplied keyword arguments override it
    kwargs = merge_dicts(dict(close=to_attach), kwargs)
    return Ranges.from_records(wrapper, range_records, **kwargs).regroup(group_by)
@classmethod
def index_from_unravel(
    cls,
    range_: tp.Array1d,
    row_idxs: tp.Array1d,
    index: tp.Index,
    signal_index_type: str = "range",
    signal_index_name: str = "signal",
):
    """Get index from an unraveling operation.

    Depending on `signal_index_type` (case-insensitive), the returned index holds:

    * "range": Values of `range_`
    * "position(s)": Values of `row_idxs`
    * "label(s)": Labels of `index` at `row_idxs`; raises an error if `row_idxs` contains -1

    The index is named `signal_index_name`. Any other type raises an error."""
    kind = signal_index_type.lower()
    if kind == "range":
        values = range_
    elif kind in ("position", "positions"):
        values = row_idxs
    elif kind in ("label", "labels"):
        # -1 marks a column without signals and cannot be translated into a label
        if -1 in row_idxs:
            raise ValueError("Some columns have no signals. Use other signal index types.")
        values = index[row_idxs]
    else:
        raise ValueError(f"Invalid signal_index_type: '{signal_index_type}'")
    return pd.Index(values, name=signal_index_name)
@hybrid_method
def unravel_between(
    cls_or_self,
    *objs,
    relation: tp.Union[int, str] = "onemany",
    incl_open_source: bool = False,
    incl_open_target: bool = False,
    incl_empty_cols: bool = True,
    broadcast_kwargs: tp.KwargsLike = None,
    force_signal_index: bool = False,
    signal_index_type: str = "pair_range",
    signal_index_name: str = "signal",
    jitted: tp.JittedOption = None,
    clean_index_kwargs: tp.KwargsLike = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeTuple[tp.SeriesFrame]:
    """Unravel signal pairs.

    If one array is passed, see `vectorbtpro.signals.nb.unravel_between_nb`.
    If two arrays are passed, see `vectorbtpro.signals.nb.unravel_between_two_nb`.

    When called on an instance, the object of the accessor is used as the first array.
    Passing any other number of arrays raises an error.

    Argument `signal_index_type` takes the following values:

    * "pair_range": Basic pair counter in a column
    * "range": Basic signal counter in a column
    * "source_range": Basic signal counter in a source column
    * "target_range": Basic signal counter in a target column
    * "position(s)": Integer position (row) of signal in a column
    * "source_position(s)": Integer position (row) of signal in a source column
    * "target_position(s)": Integer position (row) of signal in a target column
    * "label(s)": Label of signal in a column
    * "source_label(s)": Label of signal in a source column
    * "target_label(s)": Label of signal in a target column

    Args:
        *objs (array_like): One mask, or a source and a target mask.
        relation (int or str): Relation between source and target; a string is mapped to
            a field of `vectorbtpro.signals.enums.SignalRelation`.

            Passed to the Numba function only if two arrays are provided.
        incl_open_source (bool): Passed as `incl_open_source` to the Numba function.
        incl_open_target (bool): Passed as `incl_open_target` to the Numba function.

            Used only if two arrays are provided.
        incl_empty_cols (bool): Passed as `incl_empty_cols` to the Numba function.

            If False and no column is left, raises an error.
        broadcast_kwargs (dict): Keyword arguments passed to
            `vectorbtpro.base.reshaping.broadcast`. Used only if two arrays are provided.
        force_signal_index (bool): Whether to stack the signal level even if the unraveled
            shape equals the original one.
        signal_index_type (str): See above. Case-insensitive.
        signal_index_name (str): Name of the signal level. If both a source and a target level
            are stacked, they are prefixed with "source_" and "target_" respectively.
        jitted (any): Option for resolving the jitted Numba function.
        clean_index_kwargs (dict): Keyword arguments passed to
            `vectorbtpro.base.indexes.stack_indexes`.
        wrap_kwargs (dict): Keyword arguments passed to the wrapper when the signal level
            is stacked.

    Usage:
        * One mask:

        ```pycon
        >>> mask.vbt.signals.unravel_between()
        signal         -1      0      1      0      1
                        a      b      b      c      c
        2020-01-01  False   True  False   True  False
        2020-01-02  False  False  False   True   True
        2020-01-03  False   True   True  False   True
        2020-01-04  False  False  False  False  False
        2020-01-05  False  False   True  False  False

        >>> mask.vbt.signals.unravel_between(signal_index_type="position")
        source_signal     -1      0      2      0      1
        target_signal     -1      2      4      1      2
                           a      b      b      c      c
        2020-01-01     False   True  False   True  False
        2020-01-02     False  False  False   True   True
        2020-01-03     False   True   True  False   True
        2020-01-04     False  False  False  False  False
        2020-01-05     False  False   True  False  False
        ```

        * Two masks:

        ```pycon
        >>> source_mask = pd.Series([True, True, False, False, True, True])
        >>> target_mask = pd.Series([False, False, True, True, False, False])
        >>> new_source_mask, new_target_mask = vbt.pd_acc.signals.unravel_between(
        ...     source_mask,
        ...     target_mask
        ... )
        >>> new_source_mask
        signal      0      1
        0       False  False
        1        True   True
        2       False  False
        3       False  False
        4       False  False
        5       False  False
        >>> new_target_mask
        signal      0      1
        0       False  False
        1       False  False
        2        True  False
        3       False   True
        4       False  False
        5       False  False

        >>> new_source_mask, new_target_mask = vbt.pd_acc.signals.unravel_between(
        ...     source_mask,
        ...     target_mask,
        ...     relation="chain"
        ... )
        >>> new_source_mask
        signal      0      1
        0        True  False
        1       False  False
        2       False  False
        3       False  False
        4       False   True
        5       False  False
        >>> new_target_mask
        signal      0      1
        0       False  False
        1       False  False
        2        True   True
        3       False  False
        4       False  False
        5       False  False
        ```
    """
    if broadcast_kwargs is None:
        broadcast_kwargs = {}
    if clean_index_kwargs is None:
        clean_index_kwargs = {}
    if wrap_kwargs is None:
        wrap_kwargs = {}
    if isinstance(relation, str):
        relation = map_enum_fields(relation, enums.SignalRelation)
    signal_index_type = signal_index_type.lower()
    if not isinstance(cls_or_self, type):
        # Called on an instance: the accessor's own object acts as the first array
        objs = (cls_or_self.obj, *objs)

    def _build_new_columns(
        source_range,
        target_range,
        source_idxs,
        target_idxs,
        col_idxs,
    ):
        # Stack the requested signal level(s) on top of the selected original columns.
        # `wrapper` is taken from the enclosing scope and is assigned before each call
        indexes_to_stack = []
        if signal_index_type == "pair_range":
            # Build a counter over pairs that restarts whenever the column changes.
            # A restart must be detected for ANY positive step in `col_idxs`: a step can
            # exceed 1 when columns in between have been left out, and comparing against
            # exactly 1 would then leak the raw step into the counter offsets
            one_points = np.concatenate((np.array([0]), col_idxs[1:] - col_idxs[:-1]))
            basic_range = np.arange(len(col_idxs))
            range_points = np.where(one_points > 0, basic_range, 0)
            signal_range = basic_range - np.maximum.accumulate(range_points)
            # Placeholder entries (neither source nor target signal) are labeled with -1
            signal_range[(source_range == -1) & (target_range == -1)] = -1
            indexes_to_stack.append(pd.Index(signal_range, name=signal_index_name))
        else:
            if not signal_index_type.startswith("target_"):
                indexes_to_stack.append(
                    cls_or_self.index_from_unravel(
                        source_range,
                        source_idxs,
                        wrapper.index,
                        signal_index_type=signal_index_type.replace("source_", ""),
                        signal_index_name="source_" + signal_index_name,
                    )
                )
            if not signal_index_type.startswith("source_"):
                indexes_to_stack.append(
                    cls_or_self.index_from_unravel(
                        target_range,
                        target_idxs,
                        wrapper.index,
                        signal_index_type=signal_index_type.replace("target_", ""),
                        signal_index_name="target_" + signal_index_name,
                    )
                )
            if len(indexes_to_stack) == 1:
                # A single level does not need the source/target prefix
                indexes_to_stack[0] = indexes_to_stack[0].rename(signal_index_name)
        return indexes.stack_indexes((*indexes_to_stack, wrapper.columns[col_idxs]), **clean_index_kwargs)

    if len(objs) == 1:
        obj = objs[0]
        wrapper = ArrayWrapper.from_obj(obj)
        if not isinstance(obj, (pd.Series, pd.DataFrame)):
            obj = wrapper.wrap(obj)
        func = jit_reg.resolve_option(nb.unravel_between_nb, jitted)
        new_mask, source_range, target_range, source_idxs, target_idxs, col_idxs = func(
            reshaping.to_2d_array(obj),
            incl_open_source=incl_open_source,
            incl_empty_cols=incl_empty_cols,
        )
        if new_mask.shape == wrapper.shape_2d and incl_empty_cols and not force_signal_index:
            # Shape is unchanged: return with the original columns, without a signal level
            return wrapper.wrap(new_mask)
        if not incl_empty_cols and (source_idxs == -1).all():
            raise ValueError("No columns left")
        new_columns = _build_new_columns(
            source_range,
            target_range,
            source_idxs,
            target_idxs,
            col_idxs,
        )
        return wrapper.wrap(new_mask, columns=new_columns, group_by=False, **wrap_kwargs)

    if len(objs) == 2:
        source = objs[0]
        target = objs[1]
        broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
        broadcasted_args, wrapper = reshaping.broadcast(
            dict(source=source, target=target),
            return_wrapper=True,
            **broadcast_kwargs,
        )
        func = jit_reg.resolve_option(nb.unravel_between_two_nb, jitted)
        new_source_mask, new_target_mask, source_range, target_range, source_idxs, target_idxs, col_idxs = func(
            broadcasted_args["source"],
            broadcasted_args["target"],
            relation=relation,
            incl_open_source=incl_open_source,
            incl_open_target=incl_open_target,
            incl_empty_cols=incl_empty_cols,
        )
        if new_source_mask.shape == wrapper.shape_2d and incl_empty_cols and not force_signal_index:
            # Shape is unchanged: return with the original columns, without a signal level
            return wrapper.wrap(new_source_mask), wrapper.wrap(new_target_mask)
        if not incl_empty_cols and (source_idxs == -1).all() and (target_idxs == -1).all():
            raise ValueError("No columns left")
        new_columns = _build_new_columns(
            source_range,
            target_range,
            source_idxs,
            target_idxs,
            col_idxs,
        )
        new_source_mask = wrapper.wrap(new_source_mask, columns=new_columns, group_by=False, **wrap_kwargs)
        new_target_mask = wrapper.wrap(new_target_mask, columns=new_columns, group_by=False, **wrap_kwargs)
        return new_source_mask, new_target_mask
    raise ValueError("This method accepts either one or two arrays")
def norm_avg_index(
    self,
    group_by: tp.GroupByLike = None,
    jitted: tp.JittedOption = None,
    chunked: tp.ChunkedOption = None,
    wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
    """See `vectorbtpro.signals.nb.norm_avg_index_nb`.

    For a grouped frame, `vectorbtpro.signals.nb.norm_avg_index_grouped_nb` is used instead.

    Normalized average index measures the average signal location relative to the middle of the column.
    This way, we can quickly see where the majority of signals are located.

    Common values are:

    * -1.0: only the first signal is set
    * 1.0: only the last signal is set
    * 0.0: symmetric distribution around the middle
    * [-1.0, 0.0): average signal is on the left
    * (0.0, 1.0]: average signal is on the right

    Usage:
        ```pycon
        >>> pd.Series([True, False, False, False]).vbt.signals.norm_avg_index()
        -1.0

        >>> pd.Series([False, False, False, True]).vbt.signals.norm_avg_index()
        1.0

        >>> pd.Series([True, False, False, True]).vbt.signals.norm_avg_index()
        0.0
        ```
    """
    use_groups = self.is_frame() and self.wrapper.grouper.is_grouped(group_by=group_by)
    if use_groups:
        # The grouped variant additionally needs the length of each group
        extra_args = (self.wrapper.grouper.get_group_lens(group_by=group_by),)
        nb_func = nb.norm_avg_index_grouped_nb
    else:
        extra_args = ()
        nb_func = nb.norm_avg_index_nb
    resolved_func = ch_reg.resolve_option(jit_reg.resolve_option(nb_func, jitted), chunked)
    result = resolved_func(self.to_2d_array(), *extra_args)
    wrap_kwargs = merge_dicts(dict(name_or_index="norm_avg_index"), wrap_kwargs)
    return self.wrapper.wrap_reduced(result, group_by=group_by, **wrap_kwargs)
def partition_rate(
    self,
    wrap_kwargs: tp.KwargsLike = None,
    group_by: tp.GroupByLike = None,
    **kwargs,
) -> tp.MaybeSeries:
    """`SignalsAccessor.total_partitions` divided by `SignalsAccessor.total` in each column/group.

    Args:
        wrap_kwargs (dict): Keyword arguments passed to the wrapper when wrapping the result.
        group_by (any): Grouping used for both totals and for wrapping the result.
        **kwargs: Keyword arguments passed to `SignalsAccessor.total_partitions`."""
    # Forward keyword arguments as keywords (`**`): unpacking with a single `*` would pass
    # the dict keys positionally and break as soon as any keyword argument is provided
    total_partitions = reshaping.to_1d_array(self.total_partitions(group_by=group_by, **kwargs))
    # `SignalsAccessor.total` accepts only `wrap_kwargs` and `group_by`, so it must not
    # receive the keyword arguments meant for `SignalsAccessor.total_partitions`
    total = reshaping.to_1d_array(self.total(group_by=group_by))
    wrap_kwargs = merge_dicts(dict(name_or_index="partition_rate"), wrap_kwargs)
    return self.wrapper.wrap_reduced(total_partitions / total, group_by=group_by, **wrap_kwargs)
calc_func="norm_avg_index", tags=["signals", "index"]), distance=dict( title=RepEval( "f'Distance {self.get_relation_str(relation)} {target_name}' if target is not None else 'Distance'" ), calc_func="between_ranges.duration", post_calc_func=lambda self, out, settings: { "Min": out.min(), "Median": out.median(), "Max": out.max(), }, apply_to_timedelta=True, tags=RepEval("['signals', 'distance', 'target'] if target is not None else ['signals', 'distance']"), ), total_partitions=dict( title="Total Partitions", calc_func="total_partitions", tags=["signals", "partitions"], ), partition_rate=dict( title="Partition Rate [%]", calc_func="partition_rate", post_calc_func=lambda self, out, settings: out * 100, tags=["signals", "partitions"], ), partition_len=dict( title="Partition Length", calc_func="partition_ranges.duration", post_calc_func=lambda self, out, settings: { "Min": out.min(), "Median": out.median(), "Max": out.max(), }, apply_to_timedelta=True, tags=["signals", "partitions", "distance"], ), partition_distance=dict( title="Partition Distance", calc_func="between_partition_ranges.duration", post_calc_func=lambda self, out, settings: { "Min": out.min(), "Median": out.median(), "Max": out.max(), }, apply_to_timedelta=True, tags=["signals", "partitions", "distance"], ), ) ) @property def metrics(self) -> Config: return self._metrics # ############# Plotting ############# # def plot( self, yref: str = "y", column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot signals. Args: yref (str): Y coordinate axis. column (hashable): Column to plot. **kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.lineplot`. 
Usage: ```pycon >>> mask[['a', 'c']].vbt.signals.plot().show() ``` ![](/assets/images/api/signals_df_plot.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/signals_df_plot.dark.svg#only-dark){: .iimg loading=lazy } """ if column is not None: _self = self.select_col(column=column) else: _self = self default_kwargs = dict(trace_kwargs=dict(line=dict(shape="hv"))) default_kwargs["yaxis" + yref[1:]] = dict(tickmode="array", tickvals=[0, 1], ticktext=["false", "true"]) return _self.obj.vbt.lineplot(**merge_dicts(default_kwargs, kwargs)) def plot_as_markers( self, y: tp.Optional[tp.ArrayLike] = None, column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot Series as markers. Args: y (array_like): Y-axis values to plot markers on. column (hashable): Column to plot. **kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.scatterplot`. Usage: ```pycon >>> ts = pd.Series([1, 2, 3, 2, 1], index=mask.index) >>> fig = ts.vbt.lineplot() >>> mask['b'].vbt.signals.plot_as_entries(y=ts, fig=fig) >>> (~mask['b']).vbt.signals.plot_as_exits(y=ts, fig=fig).show() ``` ![](/assets/images/api/signals_plot_as_markers.light.svg#only-light){: .iimg loading=lazy } ![](/assets/images/api/signals_plot_as_markers.dark.svg#only-dark){: .iimg loading=lazy } """ from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] obj = self.obj if isinstance(obj, pd.DataFrame): obj = self.select_col_from_obj(obj, column=column) if y is None: y = pd.Series.vbt.empty_like(obj, 1) else: y = reshaping.to_pd_array(y) if isinstance(y, pd.DataFrame): y = self.select_col_from_obj(y, column=column) obj, y = reshaping.broadcast(obj, y, columns_from="keep") obj = obj.fillna(False).astype(np.bool_) if y.name is None: y = y.rename("Y") def_kwargs = dict( trace_kwargs=dict( marker=dict( symbol="circle", color=plotting_cfg["contrast_color_schema"]["blue"], size=7, ), name=obj.name, ) ) kwargs = 
merge_dicts(def_kwargs, kwargs) if "marker_color" in kwargs["trace_kwargs"]: marker_color = kwargs["trace_kwargs"]["marker_color"] else: marker_color = kwargs["trace_kwargs"]["marker"]["color"] if isinstance(marker_color, str) and "rgba" not in marker_color: line_color = adjust_lightness(marker_color) else: line_color = marker_color kwargs = merge_dicts( dict( trace_kwargs=dict( marker=dict( line=dict(width=1, color=line_color), ), ), ), kwargs, ) return y[obj].vbt.scatterplot(**kwargs) def plot_as_entries( self, y: tp.Optional[tp.ArrayLike] = None, column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot signals as entry markers. See `SignalsSRAccessor.plot_as_markers`.""" from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] return self.plot_as_markers( y=y, column=column, **merge_dicts( dict( trace_kwargs=dict( marker=dict( symbol="triangle-up", color=plotting_cfg["contrast_color_schema"]["green"], size=8, ), name="Entries", ) ), kwargs, ), ) def plot_as_exits( self, y: tp.Optional[tp.ArrayLike] = None, column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot signals as exit markers. See `SignalsSRAccessor.plot_as_markers`.""" from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] return self.plot_as_markers( y=y, column=column, **merge_dicts( dict( trace_kwargs=dict( marker=dict( symbol="triangle-down", color=plotting_cfg["contrast_color_schema"]["red"], size=8, ), name="Exits", ) ), kwargs, ), ) def plot_as_entry_marks( self, y: tp.Optional[tp.ArrayLike] = None, column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot signals as marked entry markers. 
See `SignalsSRAccessor.plot_as_markers`.""" from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] return self.plot_as_markers( y=y, column=column, **merge_dicts( dict( trace_kwargs=dict( marker=dict( symbol="circle", color="rgba(0, 0, 0, 0)", size=20, line=dict( color=plotting_cfg["contrast_color_schema"]["green"], width=2, ), ), name="Entry marks", ) ), kwargs, ), ) def plot_as_exit_marks( self, y: tp.Optional[tp.ArrayLike] = None, column: tp.Optional[tp.Label] = None, **kwargs, ) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]: """Plot signals as marked exit markers. See `SignalsSRAccessor.plot_as_markers`.""" from vectorbtpro._settings import settings plotting_cfg = settings["plotting"] return self.plot_as_markers( y=y, column=column, **merge_dicts( dict( trace_kwargs=dict( marker=dict( symbol="circle", color="rgba(0, 0, 0, 0)", size=20, line=dict( color=plotting_cfg["contrast_color_schema"]["red"], width=2, ), ), name="Exit marks", ) ), kwargs, ), ) @property def plots_defaults(self) -> tp.Kwargs: """Defaults for `SignalsAccessor.plots`. Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults` and `plots` from `vectorbtpro._settings.signals`.""" from vectorbtpro._settings import settings signals_plots_cfg = settings["signals"]["plots"] return merge_dicts(GenericAccessor.plots_defaults.__get__(self), signals_plots_cfg) @property def subplots(self) -> Config: return self._subplots SignalsAccessor.override_metrics_doc(__pdoc__) SignalsAccessor.override_subplots_doc(__pdoc__) @register_sr_vbt_accessor("signals") class SignalsSRAccessor(SignalsAccessor, GenericSRAccessor): """Accessor on top of signal series. For Series only. 
Accessible via `pd.Series.vbt.signals`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, _full_init: bool = True, **kwargs, ) -> None: GenericSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs) if _full_init: SignalsAccessor.__init__(self, wrapper, obj=obj, **kwargs)


@register_df_vbt_accessor("signals")
class SignalsDFAccessor(SignalsAccessor, GenericDFAccessor):
    """Accessor on top of signal series. For DataFrames only.

    Accessible via `pd.DataFrame.vbt.signals`."""

    def __init__(
        self,
        wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
        obj: tp.Optional[tp.ArrayLike] = None,
        _full_init: bool = True,
        **kwargs,
    ) -> None:
        # The generic DataFrame base is always initialized with `_full_init=False`;
        # `SignalsAccessor.__init__` runs only when this class is asked for a full init.
        GenericDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)

        if _full_init:
            SignalsAccessor.__init__(self, wrapper, obj=obj, **kwargs)


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Named tuples and enumerated types for signals.

Defines enums and other schemas for `vectorbtpro.signals`."""

from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify

__pdoc__all__ = __all__ = [
    "StopType",
    "SignalRelation",
    "FactoryMode",
    "GenEnContext",
    "GenExContext",
    "GenEnExContext",
    "RankContext",
]

# Documentation overrides, keyed by object name; filled in next to each definition below.
__pdoc__ = {}

# ############# Enums ############# #

# Each enum below is a named tuple whose field defaults are the integer codes;
# the exported object is an instance of that tuple created right after the class.
# NOTE(review): the markdown inside the `__pdoc__` f-strings below appears whitespace-collapsed
# in this view; the strings are left untouched — confirm they render as intended.


class StopTypeT(tp.NamedTuple):
    # NOTE(review): the meaning of each abbreviation is not documented in this file — TODO confirm.
    SL: int = 0
    TSL: int = 1
    TTP: int = 2
    TP: int = 3
    TD: int = 4
    DT: int = 5


StopType = StopTypeT()
"""_"""

__pdoc__[
    "StopType"
] = f"""Stop type.
```python {prettify(StopType)} ``` """


class SignalRelationT(tp.NamedTuple):
    # Field meanings are described in the "Attributes" part of the `__pdoc__` entry below.
    OneOne: int = 0
    OneMany: int = 1
    ManyOne: int = 2
    ManyMany: int = 3
    Chain: int = 4
    AnyChain: int = 5


SignalRelation = SignalRelationT()
"""_"""

__pdoc__[
    "SignalRelation"
] = f"""SignalRelation between two masks. ```python {prettify(SignalRelation)} ``` Attributes: OneOne: One source signal maps to exactly one succeeding target signal. OneMany: One source signal can map to one or more succeeding target signals. ManyOne: One or more source signals can map to exactly one succeeding target signal. ManyMany: One or more source signals can map to one or more succeeding target signals. Chain: First source signal maps to the first target signal after it and vice versa. AnyChain: First signal maps to the first opposite signal after it and vice versa. """


class FactoryModeT(tp.NamedTuple):
    # Field meanings are described in the "Attributes" part of the `__pdoc__` entry below.
    Entries: int = 0
    Exits: int = 1
    Both: int = 2
    Chain: int = 3


FactoryMode = FactoryModeT()
"""_"""

__pdoc__[
    "FactoryMode"
] = f"""Factory mode. ```python {prettify(FactoryMode)} ``` Attributes: Entries: Generate entries only using `generate_func_nb`. Takes no input signal arrays. Produces one output signal array - `entries`. Such generators often have no suffix. Exits: Generate exits only using `generate_ex_func_nb`. Takes one input signal array - `entries`. Produces one output signal array - `exits`. Such generators often have suffix 'X'. Both: Generate both entries and exits using `generate_enex_func_nb`. Takes no input signal arrays. Produces two output signal arrays - `entries` and `exits`. Such generators often have suffix 'NX'. Chain: Generate chain of entries and exits using `generate_enex_func_nb`. Takes one input signal array - `entries`. Produces two output signal arrays - `new_entries` and `exits`. Such generators often have suffix 'CX'.
"""

# ############# Named tuples ############# #


# Context tuple for entry generation; each field is documented through `__pdoc__` right below.
class GenEnContext(tp.NamedTuple):
    target_shape: tp.Shape
    only_once: bool
    wait: int
    entries_out: tp.Array2d
    out: tp.Array1d
    from_i: int
    to_i: int
    col: int


__pdoc__["GenEnContext"] = "Context of an entry signal generator."
__pdoc__["GenEnContext.target_shape"] = "Target shape."
__pdoc__["GenEnContext.only_once"] = "Whether to run the placement function only once."
__pdoc__["GenEnContext.wait"] = "Number of ticks to wait before placing the next entry."
__pdoc__["GenEnContext.entries_out"] = "Output array with entries."
__pdoc__["GenEnContext.out"] = "Current segment of the output array with entries."
__pdoc__["GenEnContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenEnContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenEnContext.col"] = "Column of the segment."


# Context tuple for exit generation; each field is documented through `__pdoc__` right below.
class GenExContext(tp.NamedTuple):
    entries: tp.Array2d
    until_next: bool
    skip_until_exit: bool
    exits_out: tp.Array2d
    out: tp.Array1d
    wait: int
    from_i: int
    to_i: int
    col: int


__pdoc__["GenExContext"] = "Context of an exit signal generator."
__pdoc__["GenExContext.entries"] = "Input array with entries."
__pdoc__["GenExContext.until_next"] = "Whether to place signals up to the next entry signal."
__pdoc__["GenExContext.skip_until_exit"] = "Whether to skip processing entry signals until the next exit."
__pdoc__["GenExContext.exits_out"] = "Output array with exits."
__pdoc__["GenExContext.out"] = "Current segment of the output array with exits."
__pdoc__["GenExContext.wait"] = "Number of ticks to wait before placing exits."
__pdoc__["GenExContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenExContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenExContext.col"] = "Column of the segment."
# Context tuple for combined entry/exit generation; each field is documented
# through the `__pdoc__` entries right below the class.
class GenEnExContext(tp.NamedTuple):
    target_shape: tp.Shape
    entry_wait: int
    exit_wait: int
    entries_out: tp.Array2d
    exits_out: tp.Array2d
    entries_turn: bool
    wait: int
    out: tp.Array1d
    from_i: int
    to_i: int
    col: int


# These entries must be keyed by "GenEnExContext": keying them by "GenExContext" would
# replace the exit-generator documentation and leave this class undocumented.
__pdoc__["GenEnExContext"] = "Context of an entry/exit signal generator."
__pdoc__["GenEnExContext.target_shape"] = "Target shape."
__pdoc__["GenEnExContext.entry_wait"] = "Number of ticks to wait before placing entries."
__pdoc__["GenEnExContext.exit_wait"] = "Number of ticks to wait before placing exits."
__pdoc__["GenEnExContext.entries_out"] = "Output array with entries."
__pdoc__["GenEnExContext.exits_out"] = "Output array with exits."
__pdoc__["GenEnExContext.entries_turn"] = "Whether the current turn is generating an entry."
__pdoc__["GenEnExContext.out"] = "Current segment of the output array with entries/exits."
__pdoc__["GenEnExContext.wait"] = "Number of ticks to wait before placing entries/exits."
__pdoc__["GenEnExContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenEnExContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenEnExContext.col"] = "Column of the segment."


# Context tuple handed to rankers; field documentation is registered through `__pdoc__`.
class RankContext(tp.NamedTuple):
    mask: tp.Array2d
    reset_by: tp.Optional[tp.Array2d]
    after_false: bool
    after_reset: bool
    reset_wait: int
    col: int
    i: int
    last_false_i: int
    last_reset_i: int
    all_sig_cnt: int
    all_part_cnt: int
    all_sig_in_part_cnt: int
    nonres_sig_cnt: int
    nonres_part_cnt: int
    nonres_sig_in_part_cnt: int
    sig_cnt: int
    part_cnt: int
    sig_in_part_cnt: int


__pdoc__["RankContext"] = "Context of a ranker."
__pdoc__["RankContext.mask"] = "Source mask."
__pdoc__["RankContext.reset_by"] = "Resetting mask."
# Documentation for the remaining `RankContext` fields, registered in one pass.
__pdoc__.update(
    {
        "RankContext.after_false": (
            "Whether to disregard the first partition of True values "
            "if there is no False value before them."
        ),
        "RankContext.after_reset": (
            "Whether to disregard the first partition of True values "
            "coming before the first reset signal."
        ),
        "RankContext.reset_wait": "Number of ticks to wait before resetting the current partition.",
        "RankContext.col": "Current column.",
        "RankContext.i": "Current row.",
        "RankContext.last_false_i": "Row of the last False value in the main mask.",
        "RankContext.last_reset_i": (
            "Row of the last True value in the resetting mask. "
            "Doesn't take into account `reset_wait`."
        ),
        "RankContext.all_sig_cnt": "Number of all signals encountered including this.",
        "RankContext.all_part_cnt": "Number of all partitions encountered including this.",
        "RankContext.all_sig_in_part_cnt": (
            "Number of signals encountered in the current partition including this."
        ),
        "RankContext.nonres_sig_cnt": "Number of non-resetting signals encountered including this.",
        "RankContext.nonres_part_cnt": "Number of non-resetting partitions encountered including this.",
        "RankContext.nonres_sig_in_part_cnt": (
            "Number of signals encountered in the current non-resetting partition including this."
        ),
        "RankContext.sig_cnt": "Number of valid and resetting signals encountered including this.",
        "RankContext.part_cnt": "Number of valid and resetting partitions encountered including this.",
        "RankContext.sig_in_part_cnt": (
            "Number of signals encountered in the current valid and resetting partition including this."
        ),
    }
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Factory for building signal generators. The signal factory class `SignalFactory` extends `vectorbtpro.indicators.factory.IndicatorFactory` to offer a convenient way to create signal generators of any complexity. By providing it with information such as entry and exit functions and the names of inputs, parameters, and outputs, it will create a stand-alone class capable of generating signals for an arbitrary combination of inputs and parameters. """ import inspect import numpy as np from numba import njit from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base import combining from vectorbtpro.indicators.factory import IndicatorFactory, IndicatorBase, CacheOutputT from vectorbtpro.registries.jit_registry import jit_reg from vectorbtpro.signals.enums import FactoryMode from vectorbtpro.signals.nb import generate_nb, generate_ex_nb, generate_enex_nb, first_place_nb from vectorbtpro.utils import checks from vectorbtpro.utils.config import merge_dicts from vectorbtpro.utils.enum_ import map_enum_fields from vectorbtpro.utils.params import to_typed_list __all__ = [ "SignalFactory", ] class SignalFactory(IndicatorFactory): """A factory for building signal generators. Extends `vectorbtpro.indicators.factory.IndicatorFactory` with place functions. Generates a fixed number of outputs (depending upon `mode`). If you need to generate other outputs, use in-place outputs (via `in_output_names`). See `vectorbtpro.signals.enums.FactoryMode` for supported generation modes. Other arguments are passed to `vectorbtpro.indicators.factory.IndicatorFactory`. 
""" def __init__( self, *args, mode: tp.Union[str, int] = FactoryMode.Both, input_names: tp.Optional[tp.Sequence[str]] = None, attr_settings: tp.KwargsLike = None, **kwargs, ) -> None: mode = map_enum_fields(mode, FactoryMode) if input_names is None: input_names = [] else: input_names = list(input_names) if attr_settings is None: attr_settings = {} if "entries" in input_names: raise ValueError("entries cannot be used in input_names") if "exits" in input_names: raise ValueError("exits cannot be used in input_names") if mode == FactoryMode.Entries: output_names = ["entries"] elif mode == FactoryMode.Exits: input_names = ["entries"] + input_names output_names = ["exits"] elif mode == FactoryMode.Both: output_names = ["entries", "exits"] else: input_names = ["entries"] + input_names output_names = ["new_entries", "exits"] if "entries" in input_names: attr_settings["entries"] = dict(dtype=np.bool_) for output_name in output_names: attr_settings[output_name] = dict(dtype=np.bool_) IndicatorFactory.__init__( self, *args, mode=mode, input_names=input_names, output_names=output_names, attr_settings=attr_settings, **kwargs, ) self._mode = mode def plot( _self, column: tp.Optional[tp.Label] = None, entry_y: tp.Union[None, str, tp.ArrayLike] = None, exit_y: tp.Union[None, str, tp.ArrayLike] = None, entry_types: tp.Optional[tp.ArrayLike] = None, exit_types: tp.Optional[tp.ArrayLike] = None, entry_trace_kwargs: tp.KwargsLike = None, exit_trace_kwargs: tp.KwargsLike = None, fig: tp.Optional[tp.BaseFigure] = None, **kwargs, ) -> tp.BaseFigure: self_col = _self.select_col(column=column, group_by=False) if entry_y is not None and isinstance(entry_y, str): entry_y = getattr(self_col, entry_y) if exit_y is not None and isinstance(exit_y, str): exit_y = getattr(self_col, exit_y) if entry_trace_kwargs is None: entry_trace_kwargs = {} if exit_trace_kwargs is None: exit_trace_kwargs = {} entry_trace_kwargs = merge_dicts( dict(name="New Entries" if mode == FactoryMode.Chain else 
"Entries"), entry_trace_kwargs, ) exit_trace_kwargs = merge_dicts(dict(name="Exits"), exit_trace_kwargs) if entry_types is not None: entry_types = np.asarray(entry_types) entry_trace_kwargs = merge_dicts( dict(customdata=entry_types, hovertemplate="(%{x}, %{y})
Type: %{customdata}"), entry_trace_kwargs, ) if exit_types is not None: exit_types = np.asarray(exit_types) exit_trace_kwargs = merge_dicts( dict(customdata=exit_types, hovertemplate="(%{x}, %{y})
Type: %{customdata}"), exit_trace_kwargs, ) if mode == FactoryMode.Entries: fig = self_col.entries.vbt.signals.plot_as_entries( y=entry_y, trace_kwargs=entry_trace_kwargs, fig=fig, **kwargs, ) elif mode == FactoryMode.Exits: fig = self_col.entries.vbt.signals.plot_as_entries( y=entry_y, trace_kwargs=entry_trace_kwargs, fig=fig, **kwargs, ) fig = self_col.exits.vbt.signals.plot_as_exits( y=exit_y, trace_kwargs=exit_trace_kwargs, fig=fig, **kwargs, ) elif mode == FactoryMode.Both: fig = self_col.entries.vbt.signals.plot_as_entries( y=entry_y, trace_kwargs=entry_trace_kwargs, fig=fig, **kwargs, ) fig = self_col.exits.vbt.signals.plot_as_exits( y=exit_y, trace_kwargs=exit_trace_kwargs, fig=fig, **kwargs, ) else: fig = self_col.new_entries.vbt.signals.plot_as_entries( y=entry_y, trace_kwargs=entry_trace_kwargs, fig=fig, **kwargs, ) fig = self_col.exits.vbt.signals.plot_as_exits( y=exit_y, trace_kwargs=exit_trace_kwargs, fig=fig, **kwargs, ) return fig plot.__doc__ = """Plot `{0}.{1}` and `{0}.exits`. Args: entry_y (array_like): Y-axis values to plot entry markers on. exit_y (array_like): Y-axis values to plot exit markers on. entry_types (array_like): Entry types in string format. exit_types (array_like): Exit types in string format. entry_trace_kwargs (dict): Keyword arguments passed to `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_entries` for `{0}.{1}`. exit_trace_kwargs (dict): Keyword arguments passed to `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_exits` for `{0}.exits`. fig (Figure or FigureWidget): Figure to add traces to. **kwargs: Keyword arguments passed to `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_markers`. 
""".format( self.class_name, "new_entries" if mode == FactoryMode.Chain else "entries", ) setattr(self.Indicator, "plot", plot) @property def mode(self): """Factory mode.""" return self._mode def with_place_func( self, entry_place_func_nb: tp.Optional[tp.PlaceFunc] = None, exit_place_func_nb: tp.Optional[tp.PlaceFunc] = None, generate_func_nb: tp.Optional[tp.Callable] = None, generate_ex_func_nb: tp.Optional[tp.Callable] = None, generate_enex_func_nb: tp.Optional[tp.Callable] = None, cache_func: tp.Callable = None, entry_settings: tp.KwargsLike = None, exit_settings: tp.KwargsLike = None, cache_settings: tp.KwargsLike = None, jit_kwargs: tp.KwargsLike = None, jitted: tp.JittedOption = None, **kwargs, ) -> tp.Type[IndicatorBase]: """Build signal generator class around entry and exit placement functions. A placement function is simply a function that places signals. There are two types of it: entry placement function and exit placement function. Each placement function takes broadcast time series, broadcast in-place output time series, broadcast parameter arrays, and other arguments, and returns an array of indices corresponding to chosen signals. See `vectorbtpro.signals.nb.generate_nb`. Args: entry_place_func_nb (callable): `place_func_nb` that returns indices of entries. Defaults to `vectorbtpro.signals.nb.first_place_nb` for `FactoryMode.Chain`. exit_place_func_nb (callable): `place_func_nb` that returns indices of exits. generate_func_nb (callable): Entry generation function. Defaults to `vectorbtpro.signals.nb.generate_nb`. generate_ex_func_nb (callable): Exit generation function. Defaults to `vectorbtpro.signals.nb.generate_ex_nb`. generate_enex_func_nb (callable): Entry and exit generation function. Defaults to `vectorbtpro.signals.nb.generate_enex_nb`. cache_func (callable): A caching function to preprocess data beforehand. All returned objects will be passed as last arguments to placement functions. 
entry_settings (dict): Settings dict for `entry_place_func_nb`. exit_settings (dict): Settings dict for `exit_place_func_nb`. cache_settings (dict): Settings dict for `cache_func`. jit_kwargs (dict): Keyword arguments passed to `@njit` decorator of the parameter selection function. By default, has `nogil` set to True. jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`. Gets applied to generation functions only. If the respective generation function is not jitted, then the apply function won't be jitted as well. **kwargs: Keyword arguments passed to `IndicatorFactory.with_custom_func`. !!! note Choice functions must be Numba-compiled. Which inputs, parameters and arguments to pass to each function must be explicitly indicated in the function's settings dict. By default, nothing is passed. Passing keyword arguments directly to the placement functions is not supported. Use `pass_kwargs` in a settings dict to pass keyword arguments as positional. Settings dict of each function can have the following keys: Attributes: pass_inputs (list of str): Input names to pass to the placement function. Defaults to []. Order matters. Each name must be in `input_names`. pass_in_outputs (list of str): In-place output names to pass to the placement function. Defaults to []. Order matters. Each name must be in `in_output_names`. pass_params (list of str): Parameter names to pass to the placement function. Defaults to []. Order matters. Each name must be in `param_names`. pass_kwargs (dict, list of str or list of tuple): Keyword arguments from `kwargs` dict to pass as positional arguments to the placement function. Defaults to []. Order matters. If any element is a tuple, must contain the name and the default value. If any element is a string, the default value is None. Built-in keys include: * `input_shape`: Input shape if no input time series passed. Default is provided by the pipeline if `pass_input_shape` is True. 
* `wait`: Number of ticks to wait before placing signals. Default is 1. * `until_next`: Whether to place signals up to the next entry signal. Default is True. Applied in `generate_ex_func_nb` only. * `skip_until_exit`: Whether to skip processing entry signals until the next exit. Default is False. Applied in `generate_ex_func_nb` only. * `pick_first`: Whether to stop as soon as the first exit signal is found. Default is False with `FactoryMode.Entries`, otherwise is True. * `temp_idx_arr`: Empty integer array used to temporarily store indices. Default is an automatically generated array of shape `input_shape[0]`. You can also pass `temp_idx_arr1`, `temp_idx_arr2`, etc. to generate multiple. pass_cache (bool): Whether to pass cache from `cache_func` to the placement function. Defaults to False. Cache is passed unpacked. The following arguments can be passed to `run` and `run_combs` methods: Args: *args: Can be used instead of `place_args`. place_args (tuple): Arguments passed to any placement function (depending on the mode). entry_place_args (tuple): Arguments passed to the entry placement function. exit_place_args (tuple): Arguments passed to the exit placement function. entry_args (tuple): Alias for `entry_place_args`. exit_args (tuple): Alias for `exit_place_args`. cache_args (tuple): Arguments passed to the cache function. entry_kwargs (tuple): Settings for the entry placement function. Also contains arguments passed as positional if in `pass_kwargs`. exit_kwargs (tuple): Settings for the exit placement function. Also contains arguments passed as positional if in `pass_kwargs`. cache_kwargs (tuple): Settings for the cache function. Also contains arguments passed as positional if in `pass_kwargs`. return_cache (bool): Whether to return only cache. use_cache (any): Cache to use. **kwargs: Default keyword arguments (depending on the mode). For more arguments, see `vectorbtpro.indicators.factory.IndicatorBase.run_pipeline`. 
Usage: * The simplest signal indicator that places True at the very first index: ```pycon >>> from vectorbtpro import * >>> @njit ... def entry_place_func_nb(c): ... c.out[0] = True ... return 0 >>> @njit ... def exit_place_func_nb(c): ... c.out[0] = True ... return 0 >>> MySignals = vbt.SignalFactory().with_place_func( ... entry_place_func_nb=entry_place_func_nb, ... exit_place_func_nb=exit_place_func_nb, ... entry_kwargs=dict(wait=1), ... exit_kwargs=dict(wait=1) ... ) >>> my_sig = MySignals.run(input_shape=(3, 3)) >>> my_sig.entries 0 1 2 0 True True True 1 False False False 2 True True True >>> my_sig.exits 0 1 2 0 False False False 1 True True True 2 False False False ``` * Take the first entry and place an exit after waiting `n` ticks. Find the next entry and repeat. Test three different `n` values. ```pycon >>> from vectorbtpro.signals.factory import SignalFactory >>> @njit ... def wait_place_nb(c, n): ... if n < len(c.out): ... c.out[n] = True ... return n ... return -1 >>> # Build signal generator >>> MySignals = SignalFactory( ... mode='chain', ... param_names=['n'] ... ).with_place_func( ... exit_place_func_nb=wait_place_nb, ... exit_settings=dict( ... pass_params=['n'] ... ) ... ) >>> # Run signal generator >>> entries = [True, True, True, True, True] >>> my_sig = MySignals.run(entries, [0, 1, 2]) >>> my_sig.entries # input entries custom_n 0 1 2 0 True True True 1 True True True 2 True True True 3 True True True 4 True True True >>> my_sig.new_entries # output entries custom_n 0 1 2 0 True True True 1 False False False 2 True False False 3 False True False 4 True False True >>> my_sig.exits # output exits custom_n 0 1 2 0 False False False 1 True False False 2 False True False 3 True False True 4 False False False ``` * To combine multiple iterative signals, you would need to create a custom placement function. 
Here is an example of combining two random generators using "OR" rule (the first signal wins): ```pycon >>> from vectorbtpro.indicators.configs import flex_elem_param_config >>> from vectorbtpro.signals.factory import SignalFactory >>> from vectorbtpro.signals.nb import rand_by_prob_place_nb >>> # Enum to distinguish random generators >>> RandType = namedtuple('RandType', ['R1', 'R2'])(0, 1) >>> # Define exit placement function >>> @njit ... def rand_exit_place_nb(c, rand_type, prob1, prob2): ... for out_i in range(len(c.out)): ... if np.random.uniform(0, 1) < prob1: ... c.out[out_i] = True ... rand_type[c.from_i + out_i] = RandType.R1 ... return out_i ... if np.random.uniform(0, 1) < prob2: ... c.out[out_i] = True ... rand_type[c.from_i + out_i] = RandType.R2 ... return out_i ... return -1 >>> # Build signal generator >>> MySignals = SignalFactory( ... mode='chain', ... in_output_names=['rand_type'], ... param_names=['prob1', 'prob2'], ... attr_settings=dict( ... rand_type=dict(dtype=RandType) # creates rand_type_readable ... ) ... ).with_place_func( ... exit_place_func_nb=rand_exit_place_nb, ... exit_settings=dict( ... pass_in_outputs=['rand_type'], ... pass_params=['prob1', 'prob2'] ... ), ... param_settings=dict( ... prob1=flex_elem_param_config, # param per frame/row/col/element ... prob2=flex_elem_param_config ... ), ... rand_type=-1 # fill with this value ... 
) >>> # Run signal generator >>> entries = [True, True, True, True, True] >>> my_sig = MySignals.run(entries, [0., 1.], [0., 1.], param_product=True) >>> my_sig.new_entries custom_prob1 0.0 1.0 custom_prob2 0.0 1.0 0.0 1.0 0 True True True True 1 False False False False 2 False True True True 3 False False False False 4 False True True True >>> my_sig.exits custom_prob1 0.0 1.0 custom_prob2 0.0 1.0 0.0 1.0 0 False False False False 1 False True True True 2 False False False False 3 False True True True 4 False False False False >>> my_sig.rand_type_readable custom_prob1 0.0 1.0 custom_prob2 0.0 1.0 0.0 1.0 0 1 R2 R1 R1 2 3 R2 R1 R1 4 ``` """ Indicator = self.Indicator setattr(Indicator, "entry_place_func_nb", entry_place_func_nb) setattr(Indicator, "exit_place_func_nb", exit_place_func_nb) module_name = self.module_name mode = self.mode input_names = self.input_names param_names = self.param_names in_output_names = self.in_output_names if generate_func_nb is None: generate_func_nb = generate_nb if generate_ex_func_nb is None: generate_ex_func_nb = generate_ex_nb if generate_enex_func_nb is None: generate_enex_func_nb = generate_enex_nb if jitted is not None: generate_func_nb = jit_reg.resolve_option(generate_func_nb, jitted) generate_ex_func_nb = jit_reg.resolve_option(generate_ex_func_nb, jitted) generate_enex_func_nb = jit_reg.resolve_option(generate_enex_func_nb, jitted) default_chain_entry_func = True if mode == FactoryMode.Entries: jit_apply_func = checks.is_numba_func(generate_func_nb) require_input_shape = len(input_names) == 0 checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb") if exit_place_func_nb is not None: raise ValueError("exit_place_func_nb cannot be used with FactoryMode.Entries") elif mode == FactoryMode.Exits: jit_apply_func = checks.is_numba_func(generate_ex_func_nb) require_input_shape = False if entry_place_func_nb is not None: raise ValueError("entry_place_func_nb cannot be used with FactoryMode.Exits") 
checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb") elif mode == FactoryMode.Both: jit_apply_func = checks.is_numba_func(generate_enex_func_nb) require_input_shape = len(input_names) == 0 checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb") checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb") else: jit_apply_func = checks.is_numba_func(generate_enex_func_nb) require_input_shape = False if entry_place_func_nb is None: entry_place_func_nb = first_place_nb else: default_chain_entry_func = False if entry_settings is None: entry_settings = {} entry_settings = merge_dicts(dict(pass_inputs=["entries"]), entry_settings) checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb") checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb") require_input_shape = kwargs.pop("require_input_shape", require_input_shape) if entry_settings is None: entry_settings = {} if exit_settings is None: exit_settings = {} if cache_settings is None: cache_settings = {} valid_keys = ["pass_inputs", "pass_in_outputs", "pass_params", "pass_kwargs", "pass_cache"] checks.assert_dict_valid(entry_settings, valid_keys) checks.assert_dict_valid(exit_settings, valid_keys) checks.assert_dict_valid(cache_settings, valid_keys) # Get input names for each function def _get_func_names(func_settings: tp.Kwargs, setting: str, all_names: tp.Sequence[str]) -> tp.List[str]: func_input_names = func_settings.get(setting, None) if func_input_names is None: return [] else: for name in func_input_names: checks.assert_in(name, all_names) return func_input_names entry_input_names = _get_func_names(entry_settings, "pass_inputs", input_names) exit_input_names = _get_func_names(exit_settings, "pass_inputs", input_names) cache_input_names = _get_func_names(cache_settings, "pass_inputs", input_names) entry_in_output_names = _get_func_names(entry_settings, "pass_in_outputs", in_output_names) exit_in_output_names = 
_get_func_names(exit_settings, "pass_in_outputs", in_output_names) cache_in_output_names = _get_func_names(cache_settings, "pass_in_outputs", in_output_names) entry_param_names = _get_func_names(entry_settings, "pass_params", param_names) exit_param_names = _get_func_names(exit_settings, "pass_params", param_names) cache_param_names = _get_func_names(cache_settings, "pass_params", param_names) # Build a function that selects a parameter tuple if mode == FactoryMode.Entries: _0 = "i" _0 += ", target_shape" if len(entry_input_names) > 0: _0 += ", " + ", ".join(entry_input_names) if len(entry_in_output_names) > 0: _0 += ", " + ", ".join(entry_in_output_names) if len(entry_param_names) > 0: _0 += ", " + ", ".join(entry_param_names) _0 += ", entry_args" _0 += ", only_once" _0 += ", wait" _1 = "target_shape=target_shape" _1 += ", place_func_nb=entry_place_func_nb" _1 += ", place_args=(" if len(entry_input_names) > 0: _1 += ", ".join(entry_input_names) + ", " if len(entry_in_output_names) > 0: _1 += ", ".join(map(lambda x: x + "[i]", entry_in_output_names)) + ", " if len(entry_param_names) > 0: _1 += ", ".join(map(lambda x: x + "[i]", entry_param_names)) + ", " _1 += "*entry_args,)" _1 += ", only_once=only_once" _1 += ", wait=wait" func_str = "def apply_func({0}):\n return generate_func_nb({1})".format(_0, _1) scope = {"generate_func_nb": generate_func_nb, "entry_place_func_nb": entry_place_func_nb} elif mode == FactoryMode.Exits: _0 = "i" _0 += ", entries" if len(exit_input_names) > 0: _0 += ", " + ", ".join(exit_input_names) if len(exit_in_output_names) > 0: _0 += ", " + ", ".join(exit_in_output_names) if len(exit_param_names) > 0: _0 += ", " + ", ".join(exit_param_names) _0 += ", exit_args" _0 += ", wait" _0 += ", until_next" _0 += ", skip_until_exit" _1 = "entries=entries" _1 += ", exit_place_func_nb=exit_place_func_nb" _1 += ", exit_place_args=(" if len(exit_input_names) > 0: _1 += ", ".join(exit_input_names) + ", " if len(exit_in_output_names) > 0: _1 += ", 
".join(map(lambda x: x + "[i]", exit_in_output_names)) + ", " if len(exit_param_names) > 0: _1 += ", ".join(map(lambda x: x + "[i]", exit_param_names)) + ", " _1 += "*exit_args,)" _1 += ", wait=wait" _1 += ", until_next=until_next" _1 += ", skip_until_exit=skip_until_exit" func_str = "def apply_func({0}):\n return generate_ex_func_nb({1})".format(_0, _1) scope = {"generate_ex_func_nb": generate_ex_func_nb, "exit_place_func_nb": exit_place_func_nb} else: _0 = "i" _0 += ", target_shape" if len(entry_input_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_input_names)) if len(entry_in_output_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_in_output_names)) if len(entry_param_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_param_names)) _0 += ", entry_args" if len(exit_input_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_input_names)) if len(exit_in_output_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_in_output_names)) if len(exit_param_names) > 0: _0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_param_names)) _0 += ", exit_args" _0 += ", entry_wait" _0 += ", exit_wait" _1 = "target_shape=target_shape" _1 += ", entry_place_func_nb=entry_place_func_nb" _1 += ", entry_place_args=(" if len(entry_input_names) > 0: _1 += ", ".join(map(lambda x: "_entry_" + x, entry_input_names)) + ", " if len(entry_in_output_names) > 0: _1 += ", ".join(map(lambda x: "_entry_" + x + "[i]", entry_in_output_names)) + ", " if len(entry_param_names) > 0: _1 += ", ".join(map(lambda x: "_entry_" + x + "[i]", entry_param_names)) + ", " _1 += "*entry_args,)" _1 += ", exit_place_func_nb=exit_place_func_nb" _1 += ", exit_place_args=(" if len(exit_input_names) > 0: _1 += ", ".join(map(lambda x: "_exit_" + x, exit_input_names)) + ", " if len(exit_in_output_names) > 0: _1 += ", ".join(map(lambda x: "_exit_" + x + "[i]", exit_in_output_names)) + ", " if len(exit_param_names) > 0: _1 += ", 
".join(map(lambda x: "_exit_" + x + "[i]", exit_param_names)) + ", " _1 += "*exit_args,)" _1 += ", entry_wait=entry_wait" _1 += ", exit_wait=exit_wait" func_str = "def apply_func({0}):\n return generate_enex_func_nb({1})".format(_0, _1) scope = { "generate_enex_func_nb": generate_enex_func_nb, "entry_place_func_nb": entry_place_func_nb, "exit_place_func_nb": exit_place_func_nb, } filename = inspect.getfile(lambda: None) code = compile(func_str, filename, "single") exec(code, scope) apply_func = scope["apply_func"] if module_name is not None: apply_func.__module__ = module_name if jit_apply_func: jit_kwargs = merge_dicts(dict(nogil=True), jit_kwargs) apply_func = njit(apply_func, **jit_kwargs) setattr(Indicator, "apply_func", apply_func) def custom_func( input_list: tp.List[tp.AnyArray], in_output_list: tp.List[tp.List[tp.AnyArray]], param_list: tp.List[tp.List[tp.ParamValue]], *args, input_shape: tp.Optional[tp.Shape] = None, place_args: tp.ArgsLike = None, entry_place_args: tp.ArgsLike = None, exit_place_args: tp.ArgsLike = None, entry_args: tp.ArgsLike = None, exit_args: tp.ArgsLike = None, cache_args: tp.ArgsLike = None, entry_kwargs: tp.KwargsLike = None, exit_kwargs: tp.KwargsLike = None, cache_kwargs: tp.KwargsLike = None, return_cache: bool = False, use_cache: tp.Optional[CacheOutputT] = None, execute_kwargs: tp.KwargsLike = None, **_kwargs, ) -> tp.Union[CacheOutputT, tp.Array2d, tp.List[tp.Array2d]]: # Get arguments if len(input_list) == 0: if input_shape is None: raise ValueError("Pass input_shape if no input time series were passed") else: input_shape = input_list[0].shape if len(args) > 0 and place_args is not None: raise ValueError("Must provide either *args or place_args, not both") if place_args is None: place_args = args if ( mode == FactoryMode.Entries or mode == FactoryMode.Both or (mode == FactoryMode.Chain and not default_chain_entry_func) ): if len(place_args) > 0 and entry_place_args is not None: raise ValueError("Must provide either 
place_args or entry_place_args, not both") if entry_place_args is None: entry_place_args = place_args else: if entry_place_args is None: entry_place_args = () if mode in (FactoryMode.Exits, FactoryMode.Both, FactoryMode.Chain): if len(place_args) > 0 and exit_place_args is not None: raise ValueError("Must provide either place_args or exit_place_args, not both") if exit_place_args is None: exit_place_args = place_args else: if exit_place_args is None: exit_place_args = () if len(entry_place_args) > 0 and entry_args is not None: raise ValueError("Must provide either entry_place_args or entry_args, not both") if entry_args is None: entry_args = entry_place_args if len(exit_place_args) > 0 and exit_args is not None: raise ValueError("Must provide either exit_place_args or exit_args, not both") if exit_args is None: exit_args = exit_place_args if cache_args is None: cache_args = () if ( mode == FactoryMode.Entries or mode == FactoryMode.Both or (mode == FactoryMode.Chain and not default_chain_entry_func) ): entry_kwargs = merge_dicts(_kwargs, entry_kwargs) else: if entry_kwargs is None: entry_kwargs = {} if mode in (FactoryMode.Exits, FactoryMode.Both, FactoryMode.Chain): exit_kwargs = merge_dicts(_kwargs, exit_kwargs) else: if exit_kwargs is None: exit_kwargs = {} if cache_kwargs is None: cache_kwargs = {} kwargs_defaults = dict( input_shape=input_shape, only_once=mode == FactoryMode.Entries, wait=1, until_next=True, skip_until_exit=False, pick_first=mode != FactoryMode.Entries, ) entry_kwargs = merge_dicts(kwargs_defaults, entry_kwargs) exit_kwargs = merge_dicts(kwargs_defaults, exit_kwargs) cache_kwargs = merge_dicts(kwargs_defaults, cache_kwargs) only_once = entry_kwargs["only_once"] entry_wait = entry_kwargs["wait"] exit_wait = exit_kwargs["wait"] until_next = exit_kwargs["until_next"] skip_until_exit = exit_kwargs["skip_until_exit"] # Distribute arguments across functions entry_input_list = [] exit_input_list = [] cache_input_list = [] for input_name in 
entry_input_names: entry_input_list.append(input_list[input_names.index(input_name)]) for input_name in exit_input_names: exit_input_list.append(input_list[input_names.index(input_name)]) for input_name in cache_input_names: cache_input_list.append(input_list[input_names.index(input_name)]) entry_in_output_list = [] exit_in_output_list = [] cache_in_output_list = [] for in_output_name in entry_in_output_names: entry_in_output_list.append(in_output_list[in_output_names.index(in_output_name)]) for in_output_name in exit_in_output_names: exit_in_output_list.append(in_output_list[in_output_names.index(in_output_name)]) for in_output_name in cache_in_output_names: cache_in_output_list.append(in_output_list[in_output_names.index(in_output_name)]) entry_param_list = [] exit_param_list = [] cache_param_list = [] for param_name in entry_param_names: entry_param_list.append(param_list[param_names.index(param_name)]) for param_name in exit_param_names: exit_param_list.append(param_list[param_names.index(param_name)]) for param_name in cache_param_names: cache_param_list.append(param_list[param_names.index(param_name)]) n_params = len(param_list[0]) if len(param_list) > 0 else 1 def _build_more_args(func_settings: tp.Kwargs, func_kwargs: tp.Kwargs) -> tp.Args: pass_kwargs = func_settings.get("pass_kwargs", []) if isinstance(pass_kwargs, dict): pass_kwargs = list(pass_kwargs.items()) more_args = () for key in pass_kwargs: value = None if isinstance(key, tuple): key, value = key else: if key.startswith("temp_idx_arr"): value = np.empty((input_shape[0],), dtype=int_) value = func_kwargs.get(key, value) more_args += (value,) return more_args entry_more_args = _build_more_args(entry_settings, entry_kwargs) exit_more_args = _build_more_args(exit_settings, exit_kwargs) cache_more_args = _build_more_args(cache_settings, cache_kwargs) # Caching cache = use_cache if cache is None and cache_func is not None: _cache_in_output_list = cache_in_output_list _cache_param_list = 
cache_param_list if checks.is_numba_func(cache_func): _cache_in_output_list = list(map(to_typed_list, cache_in_output_list)) _cache_param_list = list(map(to_typed_list, cache_param_list)) cache = cache_func( *cache_input_list, *_cache_in_output_list, *_cache_param_list, *cache_args, *cache_more_args, ) if return_cache: return cache if cache is None: cache = () if not isinstance(cache, tuple): cache = (cache,) entry_cache = () exit_cache = () if entry_settings.get("pass_cache", False): entry_cache = cache if exit_settings.get("pass_cache", False): exit_cache = cache # Apply and concatenate if mode == FactoryMode.Entries: _entry_in_output_list = list(map(to_typed_list, entry_in_output_list)) _entry_param_list = list(map(to_typed_list, entry_param_list)) return combining.apply_and_concat( n_params, apply_func, input_shape, *entry_input_list, *_entry_in_output_list, *_entry_param_list, entry_args + entry_more_args + entry_cache, only_once, entry_wait, n_outputs=1, jitted_loop=jit_apply_func, execute_kwargs=execute_kwargs, ) elif mode == FactoryMode.Exits: _exit_in_output_list = list(map(to_typed_list, exit_in_output_list)) _exit_param_list = list(map(to_typed_list, exit_param_list)) return combining.apply_and_concat( n_params, apply_func, input_list[0], *exit_input_list, *_exit_in_output_list, *_exit_param_list, exit_args + exit_more_args + exit_cache, exit_wait, until_next, skip_until_exit, n_outputs=1, jitted_loop=jit_apply_func, execute_kwargs=execute_kwargs, ) else: _entry_in_output_list = list(map(to_typed_list, entry_in_output_list)) _entry_param_list = list(map(to_typed_list, entry_param_list)) _exit_in_output_list = list(map(to_typed_list, exit_in_output_list)) _exit_param_list = list(map(to_typed_list, exit_param_list)) return combining.apply_and_concat( n_params, apply_func, input_shape, *entry_input_list, *_entry_in_output_list, *_entry_param_list, entry_args + entry_more_args + entry_cache, *exit_input_list, *_exit_in_output_list, *_exit_param_list, 
exit_args + exit_more_args + exit_cache, entry_wait, exit_wait, n_outputs=2, jitted_loop=jit_apply_func, execute_kwargs=execute_kwargs, ) return self.with_custom_func( custom_func, pass_packed=True, require_input_shape=require_input_shape, **kwargs, )
# ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Numba-compiled functions for signals. Provides an arsenal of Numba-compiled functions that are used by accessors and in many other parts of the backtesting pipeline, such as technical indicators. These only accept NumPy arrays and other Numba-compatible types. !!! note vectorbt treats matrices as first-class citizens and expects input arrays to be 2-dim, unless function has suffix `_1d` or is meant to be input to another function. Data is processed along index (axis 0). 
All functions passed as argument must be Numba-compiled.""" import numpy as np from numba import prange from vectorbtpro import _typing as tp from vectorbtpro._dtypes import * from vectorbtpro.base import chunking as base_ch from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb, flex_select_nb from vectorbtpro.generic import nb as generic_nb from vectorbtpro.generic.enums import range_dt, RangeStatus from vectorbtpro.records import chunking as records_ch from vectorbtpro.registries.ch_registry import register_chunkable from vectorbtpro.registries.jit_registry import register_jitted from vectorbtpro.signals.enums import * from vectorbtpro.utils import chunking as ch from vectorbtpro.utils.array_ import uniform_summing_to_one_nb, rescale_float_to_int_nb, rescale_nb from vectorbtpro.utils.template import Rep __all__ = [] # ############# Generation ############# # @register_chunkable( size=ch.ShapeSizer(arg_query="target_shape", axis=1), arg_take_spec=dict( target_shape=ch.ShapeSlicer(axis=1), place_func_nb=None, place_args=ch.ArgsTaker(), only_once=None, wait=None, ), merge_func="column_stack", ) @register_jitted(tags={"can_parallel"}) def generate_nb( target_shape: tp.Shape, place_func_nb: tp.PlaceFunc, place_args: tp.Args = (), only_once: bool = True, wait: int = 1, ) -> tp.Array2d: """Create a boolean matrix of `target_shape` and place signals using `place_func_nb`. Args: target_shape (array): Target shape. place_func_nb (callable): Signal placement function. `place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnContext`, and return the index of the last signal (-1 to break the loop). place_args: Arguments passed to `place_func_nb`. only_once (bool): Whether to run the placement function only once. wait (int): Number of ticks to wait before placing the next entry. !!! note The first argument is always a 1-dimensional boolean array that contains only those elements where signals can be placed. 
@register_chunkable(
    size=ch.ArraySizer(arg_query="entries", axis=1),
    arg_take_spec=dict(
        entries=ch.ArraySlicer(axis=1),
        exit_place_func_nb=None,
        exit_place_args=ch.ArgsTaker(),
        wait=None,
        until_next=None,
        skip_until_exit=None,
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def generate_ex_nb(
    entries: tp.Array2d,
    exit_place_func_nb: tp.PlaceFunc,
    exit_place_args: tp.Args = (),
    wait: int = 1,
    until_next: bool = True,
    skip_until_exit: bool = False,
) -> tp.Array2d:
    """Place exit signals using `exit_place_func_nb` after each signal in `entries`.

    Args:
        entries (array): Boolean array with entry signals.
        exit_place_func_nb (callable): Exit placement function.

            `exit_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenExContext`,
            and return the index of the last placed signal relative to `c.from_i`
            (-1 if no signal was placed).
        exit_place_args (tuple): Arguments unpacked and passed to `exit_place_func_nb`.
        wait (int): Number of ticks to wait before placing exits.

            !!! note
                Setting `wait` to 0 or False may result in two signals at one bar.
        until_next (bool): Whether to place signals up to the next entry signal.

            !!! note
                Setting it to False makes it difficult to tell which exit belongs to which entry.
        skip_until_exit (bool): Whether to skip processing entry signals until the next exit.

            Has only effect when `until_next` is disabled.

            !!! note
                Setting it to True makes it impossible to tell which exit belongs to which entry.

    Returns:
        Boolean array of the same shape as `entries` holding the placed exit signals.

    Raises:
        ValueError: If `wait` is negative, or if the placement function returns
            an index that falls outside of the array.
    """
    if wait < 0:
        raise ValueError("wait must be zero or greater")
    out = np.full_like(entries, False)

    def _place_exits(from_i, to_i, col, last_exit_i):
        # `from_i` is the row of the entry being processed (-1 means no entry seen yet)
        if from_i > -1:
            if skip_until_exit and from_i <= last_exit_i:
                # This entry precedes (or coincides with) the last exit, so it is skipped
                return last_exit_i
            from_i += wait
            if not until_next:
                # Exits may be placed anywhere up to the end of the column
                to_i = entries.shape[0]
            if to_i > from_i:
                c = GenExContext(
                    # Expose the input entries to the placement function;
                    # the output buffer is already available as `exits_out`
                    entries=entries,
                    until_next=until_next,
                    skip_until_exit=skip_until_exit,
                    exits_out=out,
                    out=out[from_i:to_i, col],
                    wait=wait,
                    from_i=from_i,
                    to_i=to_i,
                    col=col,
                )
                _last_exit_i = exit_place_func_nb(c, *exit_place_args)
                if _last_exit_i != -1:
                    # Translate the relative index into an absolute row
                    last_exit_i = from_i + _last_exit_i
                    if last_exit_i < from_i or last_exit_i >= entries.shape[0]:
                        raise ValueError("Last index is out of bounds")
                    if not out[last_exit_i, col]:
                        out[last_exit_i, col] = True
                elif skip_until_exit:
                    # No exit was found: signal the caller to stop processing this column
                    last_exit_i = -1
        return last_exit_i

    for col in prange(entries.shape[1]):
        from_i = -1
        last_exit_i = -1
        should_stop = False
        for i in range(entries.shape[0]):
            if entries[i, col]:
                # Place exits for the previous entry within [previous entry, current entry)
                last_exit_i = _place_exits(from_i, i, col, last_exit_i)
                if skip_until_exit and last_exit_i == -1 and from_i != -1:
                    should_stop = True
                    break
                from_i = i
        if should_stop:
            continue
        # Place exits for the last entry up to the end of the column
        last_exit_i = _place_exits(from_i, entries.shape[0], col, last_exit_i)
    return out
tp.Tuple[tp.Array2d, tp.Array2d]: """Place entry signals using `entry_place_func_nb` and exit signals using `exit_place_func_nb` one after another. Args: target_shape (array): Target shape. entry_place_func_nb (callable): Entry place function. `entry_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnExContext`, and return the index of the last signal (-1 to break the loop). entry_place_args (tuple): Arguments unpacked and passed to `entry_place_func_nb`. exit_place_func_nb (callable): Exit place function. `exit_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnExContext`, and return the index of the last signal (-1 to break the loop). exit_place_args (tuple): Arguments unpacked and passed to `exit_place_func_nb`. entry_wait (int): Number of ticks to wait before placing entries. !!! note Setting `entry_wait` to 0 or False assumes that both entry and exit can be processed within the same bar, and exit can be processed before entry. exit_wait (int): Number of ticks to wait before placing exits. !!! note Setting `exit_wait` to 0 or False assumes that both entry and exit can be processed within the same bar, and entry can be processed before exit. 
""" if entry_wait < 0: raise ValueError("entry_wait must be zero or greater") if exit_wait < 0: raise ValueError("exit_wait must be zero or greater") if entry_wait == 0 and exit_wait == 0: raise ValueError("entry_wait and exit_wait cannot be both 0") entries = np.full(target_shape, False) exits = np.full(target_shape, False) def _place_signals(entries_turn, out, from_i, col, wait, place_func_nb, args): to_i = target_shape[0] if to_i > from_i: c = GenEnExContext( target_shape=target_shape, entry_wait=entry_wait, exit_wait=exit_wait, entries_out=entries, exits_out=exits, entries_turn=entries_turn, out=out[from_i:to_i, col], wait=wait if from_i > 0 else 0, from_i=from_i, to_i=to_i, col=col, ) _last_i = place_func_nb(c, *args) if _last_i == -1: return -1 last_i = from_i + _last_i if last_i < from_i or last_i >= target_shape[0]: raise ValueError("Last index is out of bounds") if not out[last_i, col]: out[last_i, col] = True return last_i return -1 for col in range(target_shape[1]): from_i = 0 entries_turn = True first_signal = True while from_i != -1: if entries_turn: if not first_signal: from_i += entry_wait from_i = _place_signals( entries_turn, entries, from_i, col, entry_wait, entry_place_func_nb, entry_place_args, ) entries_turn = False else: from_i += exit_wait from_i = _place_signals( entries_turn, exits, from_i, col, exit_wait, exit_place_func_nb, exit_place_args, ) entries_turn = True first_signal = False return entries, exits # ############# Random signals ############# # @register_jitted def rand_place_nb(c: tp.Union[GenEnContext, GenExContext, GenEnExContext], n: tp.FlexArray1d) -> int: """`place_func_nb` to randomly pick `n` values. 
@register_jitted
def rand_by_prob_place_nb(
    c: tp.Union[GenEnContext, GenExContext, GenEnExContext],
    prob: tp.FlexArray2d,
    pick_first: bool = False,
) -> int:
    """`place_func_nb` that sets each element of `c.out` to True with probability `prob`.

    One uniform number from [0, 1) is drawn per row between `c.from_i` and `c.to_i`;
    a signal is placed whenever the draw is below the probability selected for that
    row and column `c.col`. If `pick_first` is True, the function stops after
    the first placed signal.

    `prob` uses flexible indexing.

    Returns the index (relative to `c.from_i`) of the last placed signal, or -1 if none."""
    n_rows = c.to_i - c.from_i
    placed_i = -1
    for rel_i in range(n_rows):
        draw = np.random.uniform(0, 1)
        threshold = flex_select_nb(prob, c.from_i + rel_i, c.col)
        if draw < threshold:
            c.out[rel_i] = True
            placed_i = rel_i
            if pick_first:
                return placed_i
    return placed_i
`n` uses flexible indexing and thus must be at least a 0-dim array.""" entries = np.full(target_shape, False) exits = np.full(target_shape, False) if entry_wait == 0 and exit_wait == 0: raise ValueError("entry_wait and exit_wait cannot be both 0") if entry_wait == 1 and exit_wait == 1: # Basic case both = generate_nb(target_shape, rand_place_nb, (n * 2,), only_once=True, wait=1) for col in prange(both.shape[1]): both_idxs = np.flatnonzero(both[:, col]) entries[both_idxs[0::2], col] = True exits[both_idxs[1::2], col] = True else: for col in prange(target_shape[1]): _n = flex_select_1d_pc_nb(n, col) if _n == 1: entry_idx = np.random.randint(0, target_shape[0] - exit_wait) entries[entry_idx, col] = True else: # Minimum range between two entries min_range = entry_wait + exit_wait # Minimum total range between first and last entry min_total_range = min_range * (_n - 1) if target_shape[0] < min_total_range + exit_wait + 1: raise ValueError("Cannot take a larger sample than population") # We should decide how much space should be allocate before first and after last entry # Maximum space outside of min_total_range max_free_space = target_shape[0] - min_total_range - 1 # If min_total_range is tiny compared to max_free_space, limit it # otherwise we would have huge space before first and after last entry # Limit it such as distribution of entries mimics uniform free_space = min(max_free_space, 3 * target_shape[0] // (_n + 1)) # What about last exit? 
def rand_enex_apply_nb(
    target_shape: tp.Shape,
    n: tp.FlexArray1d,
    entry_wait: int = 1,
    exit_wait: int = 1,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
    """`apply_func_nb` that forwards all of its arguments to `generate_rand_enex_nb`
    and returns its result unchanged."""
    entries_and_exits = generate_rand_enex_nb(
        target_shape,
        n,
        entry_wait=entry_wait,
        exit_wait=exit_wait,
    )
    return entries_and_exits
@register_jitted
def stop_place_nb(
    c: tp.Union[GenExContext, GenEnExContext],
    entry_ts: tp.FlexArray2d,
    ts: tp.FlexArray2d,
    follow_ts: tp.FlexArray2d,
    stop_ts_out: tp.Array2d,
    stop: tp.FlexArray2d,
    trailing: tp.FlexArray2d,
) -> int:
    """`place_func_nb` that emits an exit signal at the first row where a stop threshold is hit.

    The entry price, the stop value, and the trailing flag are all read once at row
    `c.from_i - c.wait` and stay fixed for the whole segment.

    !!! note
        Waiting time cannot be higher than 1.

        If waiting time is 0, `entry_ts` should be the first value in the bar.
        If waiting time is 1, `entry_ts` should be the last value in the bar.

    Args:
        c (GenExContext or GenEnExContext): Signal context.
        entry_ts (array of float): Entry price.

            Utilizes flexible indexing.
        ts (array of float): Price that is compared against the stop price.

            Utilizes flexible indexing. If NaN, defaults to `entry_ts`.
        follow_ts (array of float): Price that a trailing stop follows.

            Utilizes flexible indexing. If NaN, it is derived from `entry_ts` and `ts`
            (their minimum for a non-negative stop, their maximum otherwise), or from `ts`
            alone if `entry_ts` is NaN too. Only relevant if the stop is trailing.
        stop_ts_out (array of float): Array where the hit price of each exit is written.

            Must be of the full shape.
        stop (array of float): Stop value.

            Utilizes flexible indexing. A non-negative value triggers once the price reaches
            the threshold from below, a negative one once it reaches it from above.
            NaN (and also 0) disables the stop.
        trailing (array of bool): Whether the stop is trailing.

            Utilizes flexible indexing. Set an element to False to disable it.

    Returns the position of the placed signal relative to `c.from_i`, or -1 if nothing was placed.
    """
    if c.wait > 1:
        raise ValueError("Wait must be either 0 or 1")

    col = c.col
    entry_row = c.from_i - c.wait
    base_price = flex_select_nb(entry_ts, entry_row, col)
    stop_value = flex_select_nb(stop, entry_row, col)
    if stop_value == 0:
        # A zero stop is treated the same way as a disabled one
        stop_value = np.nan
    is_trailing = flex_select_nb(trailing, entry_row, col)

    stop_enabled = not np.isnan(stop_value)
    # False for a negative stop and also for NaN
    upward = stop_value >= 0
    lowest = base_price
    highest = base_price
    hit_at = -1

    for i in range(c.from_i, c.to_i):
        bar_entry = flex_select_nb(entry_ts, i, col)
        bar_price = flex_select_nb(ts, i, col)
        bar_follow = flex_select_nb(follow_ts, i, col)

        # Fill in missing prices
        if np.isnan(bar_price):
            bar_price = bar_entry
        if np.isnan(bar_follow):
            if np.isnan(bar_entry):
                bar_follow = bar_price
            elif upward:
                bar_follow = min(bar_entry, bar_price)
            else:
                bar_follow = max(bar_entry, bar_price)

        if stop_enabled:
            # Resolve the threshold that applies to this row
            if not is_trailing:
                threshold = base_price * (1 + stop_value)
            elif upward:
                # Trailing stop buy
                threshold = lowest * (1 + abs(stop_value))
            else:
                # Trailing stop sell
                threshold = highest * (1 - abs(stop_value))

            # Has the price crossed the threshold?
            if upward:
                triggered = bar_price >= threshold
            else:
                triggered = bar_price <= threshold
            if triggered:
                stop_ts_out[i, col] = threshold
                c.out[i - c.from_i] = True
                hit_at = i - c.from_i
                break

        # Update the extremes that a trailing stop is anchored to
        if is_trailing:
            if bar_follow < lowest:
                lowest = bar_follow
            elif bar_follow > highest:
                highest = bar_follow

    return hit_at
threshold is being hit using OHLC. Compared to `stop_place_nb`, takes into account the whole bar, can check for both (trailing) stop loss and take profit simultaneously, and tracks hit price and stop type. !!! note Waiting time cannot be higher than 1. Args: c (GenExContext or GenEnExContext): Signal context. entry_price (array of float): Entry price. Utilizes flexible indexing. open (array of float): Open price. Utilizes flexible indexing. If Nan and `is_entry_open` is True, defaults to entry price. high (array of float): High price. Utilizes flexible indexing. If NaN, gets calculated from open and close. low (array of float): Low price. Utilizes flexible indexing. If NaN, gets calculated from open and close. close (array of float): Close price. Utilizes flexible indexing. If Nan and `is_entry_open` is False, defaults to entry price. stop_price_out (array of float): Array where hit price of each exit will be stored. Must be of the full shape. stop_type_out (array of int): Array where stop type of each exit will be stored. Must be of the full shape. 0 for stop loss, 1 for take profit. sl_stop (array of float): Stop loss as a percentage. Utilizes flexible indexing. Set an element to `np.nan` to disable. tsl_th (array of float): Take profit threshold as a percentage for the trailing stop loss. Utilizes flexible indexing. Set an element to `np.nan` to disable. tsl_stop (array of float): Trailing stop loss as a percentage for the trailing stop loss. Utilizes flexible indexing. Set an element to `np.nan` to disable. tp_stop (array of float): Take profit as a percentage. Utilizes flexible indexing. Set an element to `np.nan` to disable. reverse (array of float): Whether to do the opposite, i.e.: prices are followed downwards. Utilizes flexible indexing. is_entry_open (bool): Whether entry price comes right at or before open. If True, uses high and low of the entry bar. Otherwise, uses only close. 
""" if c.wait > 1: raise ValueError("Wait must be either 0 or 1") init_i = c.from_i - c.wait init_entry_price = flex_select_nb(entry_price, init_i, c.col) init_sl_stop = abs(flex_select_nb(sl_stop, init_i, c.col)) init_tp_stop = abs(flex_select_nb(tp_stop, init_i, c.col)) init_tsl_th = abs(flex_select_nb(tsl_th, init_i, c.col)) init_tsl_stop = abs(flex_select_nb(tsl_stop, init_i, c.col)) init_reverse = flex_select_nb(reverse, init_i, c.col) last_high = last_low = init_entry_price last_i = -1 for i in range(c.from_i - c.wait, c.to_i): # Resolve current bar _entry_price = flex_select_nb(entry_price, i, c.col) _open = flex_select_nb(open, i, c.col) _high = flex_select_nb(high, i, c.col) _low = flex_select_nb(low, i, c.col) _close = flex_select_nb(close, i, c.col) if np.isnan(_open) and not np.isnan(_entry_price) and is_entry_open: _open = _entry_price if np.isnan(_close) and not np.isnan(_entry_price) and not is_entry_open: _close = _entry_price if np.isnan(_high): if np.isnan(_open): _high = _close elif np.isnan(_close): _high = _open else: _high = max(_open, _close) if np.isnan(_low): if np.isnan(_open): _low = _close elif np.isnan(_close): _low = _open else: _low = min(_open, _close) if i > init_i or is_entry_open: curr_high = _high curr_low = _low else: curr_high = curr_low = _close if i >= c.from_i: # Calculate stop prices if not np.isnan(init_sl_stop): if init_reverse: curr_sl_stop_price = init_entry_price * (1 + init_sl_stop) else: curr_sl_stop_price = init_entry_price * (1 - init_sl_stop) if not np.isnan(init_tsl_stop): if np.isnan(init_tsl_th): if init_reverse: curr_tsl_stop_price = last_low * (1 + init_tsl_stop) else: curr_tsl_stop_price = last_high * (1 - init_tsl_stop) else: if init_reverse: if last_low <= init_entry_price * (1 - init_tsl_th): curr_tsl_stop_price = last_low * (1 + init_tsl_stop) else: curr_tsl_stop_price = np.nan else: if last_high >= init_entry_price * (1 + init_tsl_th): curr_tsl_stop_price = last_high * (1 - init_tsl_stop) else: 
curr_tsl_stop_price = np.nan if not np.isnan(init_tp_stop): if init_reverse: curr_tp_stop_price = init_entry_price * (1 - init_tp_stop) else: curr_tp_stop_price = init_entry_price * (1 + init_tp_stop) # Check if stop price is within bar exit_signal = False if not np.isnan(init_sl_stop): # SL hit? stop_price = np.nan if not init_reverse: if _open <= curr_sl_stop_price: stop_price = _open if curr_low <= curr_sl_stop_price: stop_price = curr_sl_stop_price else: if _open >= curr_sl_stop_price: stop_price = _open if curr_high >= curr_sl_stop_price: stop_price = curr_sl_stop_price if not np.isnan(stop_price): stop_price_out[i, c.col] = stop_price stop_type_out[i, c.col] = StopType.SL exit_signal = True if not exit_signal and not np.isnan(init_tsl_stop): # TSL/TTP hit? stop_price = np.nan if not init_reverse: if _open <= curr_tsl_stop_price: stop_price = _open if curr_low <= curr_tsl_stop_price: stop_price = curr_tsl_stop_price else: if _open >= curr_tsl_stop_price: stop_price = _open if curr_high >= curr_tsl_stop_price: stop_price = curr_tsl_stop_price if not np.isnan(stop_price): stop_price_out[i, c.col] = stop_price if np.isnan(init_tsl_th): stop_type_out[i, c.col] = StopType.TSL else: stop_type_out[i, c.col] = StopType.TTP exit_signal = True if not exit_signal and not np.isnan(init_tp_stop): # TP hit? 
# NOTE(review): `mask` is sliced per column chunk while `reset_by` is passed through as-is (None);
# if chunks receive re-based column indices, `reset_by[i, col]` may address the wrong column — confirm.
@register_chunkable(
    size=ch.ArraySizer(arg_query="mask", axis=1),
    arg_take_spec=dict(
        mask=ch.ArraySlicer(axis=1),
        rank_func_nb=None,
        rank_args=ch.ArgsTaker(),
        reset_by=None,
        after_false=None,
        after_reset=None,
        reset_wait=None,
    ),
    merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rank_nb(
    mask: tp.Array2d,
    rank_func_nb: tp.RankFunc,
    rank_args: tp.Args = (),
    reset_by: tp.Optional[tp.Array2d] = None,
    after_false: bool = False,
    after_reset: bool = False,
    reset_wait: int = 1,
) -> tp.Array2d:
    """Rank each signal using `rank_func_nb`.

    Applies `rank_func_nb` on each True value. Must accept a context of type
    `vectorbtpro.signals.enums.RankContext`. Must return -1 for no rank, otherwise 0 or greater.

    Setting `after_false` to True will disregard the first partition of True values
    if there is no False value before them. Setting `after_reset` to True will disregard
    the first partition of True values coming before the first reset signal. Setting `reset_wait`
    to 0 will treat the signal at the same position as the reset signal as the first signal in
    the next partition. Setting it to 1 will treat it as the last signal in the previous partition.

    Three families of counters are handed to `rank_func_nb` through the context:

    * `all_*`: count every True value and are never reset by `reset_by`
    * `nonres_*`: count only True values that pass the `after_false`/`after_reset` filters
        and are never reset by `reset_by`
    * `sig_cnt`, `part_cnt`, `sig_in_part_cnt`: same as `nonres_*` but zeroed whenever a reset takes effect

    Returns an integer array of the same shape as `mask` that holds -1 wherever
    `rank_func_nb` wasn't called."""
    out = np.full(mask.shape, -1, dtype=int_)

    for col in prange(mask.shape[1]):
        in_partition = False
        # Without the respective filter, the condition counts as already met
        false_seen = not after_false
        reset_seen = reset_by is None
        last_false_i = -1
        last_reset_i = -1
        all_sig_cnt = 0
        all_part_cnt = 0
        all_sig_in_part_cnt = 0
        nonres_sig_cnt = 0
        nonres_part_cnt = 0
        nonres_sig_in_part_cnt = 0
        sig_cnt = 0
        part_cnt = 0
        sig_in_part_cnt = 0

        for i in range(mask.shape[0]):
            if reset_by is not None and reset_by[i, col]:
                last_reset_i = i
            # A reset takes effect exactly `reset_wait` rows after the latest reset signal
            if last_reset_i > -1 and i - last_reset_i == reset_wait:
                reset_seen = True
                sig_cnt = 0
                part_cnt = 0
                sig_in_part_cnt = 0
            if mask[i, col]:
                all_sig_cnt += 1
                # A True value that doesn't follow another True value opens a new partition
                if i == 0 or not mask[i - 1, col]:
                    all_part_cnt += 1
                all_sig_in_part_cnt += 1
                # Rank only once every enabled filter has been satisfied
                if not (after_false and not false_seen) and not (after_reset and not reset_seen):
                    nonres_sig_cnt += 1
                    sig_cnt += 1
                    if not in_partition:
                        nonres_part_cnt += 1
                        part_cnt += 1
                    elif last_reset_i > -1 and i - last_reset_i == reset_wait:
                        # The reset took effect in the middle of a partition: the remainder
                        # counts as a new partition for the resettable counter only
                        part_cnt += 1
                    nonres_sig_in_part_cnt += 1
                    sig_in_part_cnt += 1
                    in_partition = True
                    c = RankContext(
                        mask=mask,
                        reset_by=reset_by,
                        after_false=after_false,
                        after_reset=after_reset,
                        reset_wait=reset_wait,
                        col=col,
                        i=i,
                        last_false_i=last_false_i,
                        last_reset_i=last_reset_i,
                        all_sig_cnt=all_sig_cnt,
                        all_part_cnt=all_part_cnt,
                        all_sig_in_part_cnt=all_sig_in_part_cnt,
                        nonres_sig_cnt=nonres_sig_cnt,
                        nonres_part_cnt=nonres_part_cnt,
                        nonres_sig_in_part_cnt=nonres_sig_in_part_cnt,
                        sig_cnt=sig_cnt,
                        part_cnt=part_cnt,
                        sig_in_part_cnt=sig_in_part_cnt,
                    )
                    out[i, col] = rank_func_nb(c, *rank_args)
            else:
                # A False value closes the current partition and restarts all in-partition counters
                all_sig_in_part_cnt = 0
                nonres_sig_in_part_cnt = 0
                sig_in_part_cnt = 0
                last_false_i = i
                in_partition = False
                false_seen = True

    return out
@register_jitted(cache=True)
def distance_from_last_1d_nb(mask: tp.Array1d, nth: int = 1) -> tp.Array1d:
    """Number of rows between the current row and the n-th most recent True value.

    Rows for which not enough True values have been registered yet get -1.

    With `nth` of 1, the distance is measured before a True value at the current row
    is registered, so the current row never refers to itself. With any other `nth`,
    a True value at the current row is registered first: for `nth` of 0 the most recent
    registered value is used, for `nth` above 1 the n-th most recent one."""
    if nth < 0:
        raise ValueError("nth must be at least 0")

    n_rows = mask.shape[0]
    out = np.empty(mask.shape, dtype=int_)
    seen = np.empty(mask.shape, dtype=int_)
    n_seen = 0
    # Only nth == 1 measures the distance before registering the current row
    register_first = nth != 1
    # nth == 0 looks one registered value back, exactly like nth == 1
    lookback = nth if nth > 0 else 1

    for i in range(n_rows):
        if register_first and mask[i]:
            seen[n_seen] = i
            n_seen += 1
        if n_seen < lookback:
            out[i] = -1
        else:
            out[i] = i - seen[n_seen - lookback]
        if not register_first and mask[i]:
            seen[n_seen] = i
            n_seen += 1
    return out
@register_chunkable(
    size=ch.ArraySizer(arg_query="entries", axis=1),
    arg_take_spec=dict(
        entries=ch.ArraySlicer(axis=1),
        exits=ch.ArraySlicer(axis=1),
        force_first=None,
        keep_conflicts=None,
        reverse_order=None,
    ),
    merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def clean_enex_nb(
    entries: tp.Array2d,
    exits: tp.Array2d,
    force_first: bool = True,
    keep_conflicts: bool = False,
    reverse_order: bool = False,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
    """Column-wise version of `clean_enex_1d_nb`.

    Runs `clean_enex_1d_nb` on each pair of columns with the same options and
    collects the cleaned entries and exits into two new boolean arrays."""
    n_cols = entries.shape[1]
    entries_out = np.empty(entries.shape, dtype=np.bool_)
    exits_out = np.empty(exits.shape, dtype=np.bool_)

    for col in prange(n_cols):
        col_entries, col_exits = clean_enex_1d_nb(
            entries[:, col],
            exits[:, col],
            force_first=force_first,
            keep_conflicts=keep_conflicts,
            reverse_order=reverse_order,
        )
        entries_out[:, col] = col_entries
        exits_out[:, col] = col_exits
    return entries_out, exits_out
relation_idxs_1d_nb( source_mask: tp.Array1d, target_mask: tp.Array1d, relation: int = SignalRelation.OneMany, ) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]: """Get index pairs of True values between a source and target mask. For `relation`, see `vectorbtpro.signals.enums.SignalRelation`. !!! note If both True values happen at the same time, source signal is assumed to come first.""" if relation == SignalRelation.Chain or relation == SignalRelation.AnyChain: max_signals = source_mask.shape[0] * 2 else: max_signals = source_mask.shape[0] source_range_out = np.full(max_signals, -1, dtype=int_) target_range_out = np.full(max_signals, -1, dtype=int_) source_idxs_out = np.full(max_signals, -1, dtype=int_) target_idxs_out = np.full(max_signals, -1, dtype=int_) source_j = -1 target_j = -1 k = 0 if relation == SignalRelation.OneOne: fill_k = -1 for i in range(source_mask.shape[0]): if source_mask[i]: source_j += 1 if target_mask[i]: target_j += 1 if source_mask[i]: source_range_out[k] = source_j source_idxs_out[k] = i if fill_k == -1: fill_k = k k += 1 if target_mask[i]: if fill_k == -1: target_range_out[k] = target_j target_idxs_out[k] = i k += 1 else: target_range_out[fill_k] = target_j target_idxs_out[fill_k] = i fill_k += 1 if fill_k == k: fill_k = -1 elif relation == SignalRelation.OneMany: source_idx = -1 source_placed = True for i in range(source_mask.shape[0]): if source_mask[i]: source_j += 1 if target_mask[i]: target_j += 1 if source_mask[i]: if not source_placed: source_range_out[k] = source_j source_idxs_out[k] = source_idx k += 1 source_idx = i source_placed = False if target_mask[i]: if source_idx == -1: target_range_out[k] = target_j target_idxs_out[k] = i k += 1 else: source_range_out[k] = source_j target_range_out[k] = target_j source_idxs_out[k] = source_idx target_idxs_out[k] = i k += 1 source_placed = True if not source_placed: source_range_out[k] = source_j source_idxs_out[k] = source_idx k += 1 elif relation == SignalRelation.ManyOne: 
target_idx = -1 target_placed = True for i in range(source_mask.shape[0] - 1, -1, -1): if source_mask[i]: source_j += 1 if target_mask[i]: target_j += 1 if target_mask[i]: if not target_placed: target_range_out[k] = target_j target_idxs_out[k] = target_idx k += 1 target_idx = i target_placed = False if source_mask[i]: if target_idx == -1: source_range_out[k] = source_j source_idxs_out[k] = i k += 1 else: source_range_out[k] = source_j target_range_out[k] = target_j source_idxs_out[k] = i target_idxs_out[k] = target_idx k += 1 target_placed = True if not target_placed: target_range_out[k] = target_j target_idxs_out[k] = target_idx k += 1 source_range_out[:k] = source_range_out[:k][::-1] target_range_out[:k] = target_range_out[:k][::-1] for _k in range(k): source_range_out[_k] = source_j - source_range_out[_k] target_range_out[_k] = target_j - target_range_out[_k] source_idxs_out[:k] = source_idxs_out[:k][::-1] target_idxs_out[:k] = target_idxs_out[:k][::-1] elif relation == SignalRelation.ManyMany: source_idx = -1 from_k = -1 for i in range(source_mask.shape[0]): if source_mask[i]: source_j += 1 if target_mask[i]: target_j += 1 if source_mask[i]: source_idx = i source_range_out[k] = source_j source_idxs_out[k] = source_idx if from_k == -1: from_k = k k += 1 if target_mask[i]: if from_k == -1: if source_idx != -1: source_range_out[k] = source_j source_idxs_out[k] = source_idx target_range_out[k] = target_j target_idxs_out[k] = i k += 1 else: for _k in range(from_k, k): target_range_out[_k] = target_j target_idxs_out[_k] = i from_k = -1 elif relation == SignalRelation.Chain: source_idx = -1 target_idx = -1 any_placed = False for i in range(source_mask.shape[0]): if source_mask[i]: source_j += 1 if target_mask[i]: target_j += 1 if source_mask[i]: if source_idx == -1: source_idx = i source_range_out[k] = source_j source_idxs_out[k] = source_idx any_placed = True if target_idx != -1: target_idx = -1 k += 1 source_range_out[k] = source_j source_idxs_out[k] = source_idx if 
@register_chunkable(
    size=ch.ArraySizer(arg_query="mask", axis=1),
    arg_take_spec=dict(mask=ch.ArraySlicer(axis=1), incl_open=None),
    merge_func=records_ch.merge_records,
    merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def between_ranges_nb(mask: tp.Array2d, incl_open: bool = False) -> tp.RecordArray:
    """Create a record of type `vectorbtpro.generic.enums.range_dt` for each range between two
    signals in `mask`.

    Each pair of successive True values in a column yields a closed range that starts at
    the first and ends at the second one. If `incl_open` is True, an open range from the last
    True value to the last row is appended, unless that True value already sits on the last row.
    Columns without any True value produce no records.

    Record ids are counted per column. Records are compressed with
    `vectorbtpro.generic.nb.repartition_nb` before being returned."""
    new_records = np.empty(mask.shape, dtype=range_dt)
    counts = np.full(mask.shape[1], 0, dtype=int_)

    for col in prange(mask.shape[1]):
        # Row of the most recent signal, -1 as long as none has been seen
        from_i = -1
        for i in range(mask.shape[0]):
            if mask[i, col]:
                if from_i > -1:
                    r = counts[col]
                    new_records["id"][r, col] = r
                    new_records["col"][r, col] = col
                    new_records["start_idx"][r, col] = from_i
                    new_records["end_idx"][r, col] = i
                    new_records["status"][r, col] = RangeStatus.Closed
                    counts[col] += 1
                from_i = i
        # An open range requires an actual signal to start from: without the lower bound,
        # a column with no signals would yield a record with a start index of -1
        if incl_open and from_i > -1 and from_i < mask.shape[0] - 1:
            r = counts[col]
            new_records["id"][r, col] = r
            new_records["col"][r, col] = col
            new_records["start_idx"][r, col] = from_i
            new_records["end_idx"][r, col] = mask.shape[0] - 1
            new_records["status"][r, col] = RangeStatus.Open
            counts[col] += 1

    return generic_nb.repartition_nb(new_records, counts)
@register_chunkable(
    size=ch.ArraySizer(arg_query="mask", axis=1),
    arg_take_spec=dict(mask=ch.ArraySlicer(axis=1)),
    merge_func=records_ch.merge_records,
    merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def partition_ranges_nb(mask: tp.Array2d) -> tp.RecordArray:
    """Build one `vectorbtpro.generic.enums.range_dt` record per partition of True values in `mask`.

    A partition that is followed by a False value becomes a closed range ending at the row
    of that False value. A partition that lasts until the end of the column becomes an open
    range ending at the last row."""
    n_rows, n_cols = mask.shape
    new_records = np.empty(mask.shape, dtype=range_dt)
    counts = np.full(n_cols, 0, dtype=int_)

    for col in prange(n_cols):
        part_start = -1
        inside = False
        for i in range(n_rows):
            if mask[i, col]:
                if not inside:
                    # First True value of a new partition
                    part_start = i
                    inside = True
            elif inside:
                # First False value after a partition closes it
                r = counts[col]
                new_records["id"][r, col] = r
                new_records["col"][r, col] = col
                new_records["start_idx"][r, col] = part_start
                new_records["end_idx"][r, col] = i
                new_records["status"][r, col] = RangeStatus.Closed
                counts[col] += 1
                inside = False

        if inside:
            # Partition is still running at the last row
            r = counts[col]
            new_records["id"][r, col] = r
            new_records["col"][r, col] = col
            new_records["start_idx"][r, col] = part_start
            new_records["end_idx"][r, col] = n_rows - 1
            new_records["status"][r, col] = RangeStatus.Open
            counts[col] += 1

    return generic_nb.repartition_nb(new_records, counts)
@register_jitted(cache=True)
def unravel_between_nb(
    mask: tp.Array2d,
    incl_open_source: bool = False,
    incl_empty_cols: bool = True,
) -> tp.Tuple[
    tp.Array2d,
    tp.Array1d,
    tp.Array1d,
    tp.Array1d,
    tp.Array1d,
    tp.Array1d,
]:
    """Spread each pair of successive True values in a mask over its own column.

    The last True value of a column has no successor; it gets a column of its own
    (with the target fields left at -1) only if `incl_open_source` is True. A column that
    contributes no pair is kept as a single all-False column if `incl_empty_cols` is True.

    Returns the new mask, the index of each source True value in its column, the index of each
    target True value in its column, the row index of each source True value in the original mask,
    the row index of each target True value in the original mask, and the column index of each
    True value in the original mask."""
    n_rows = mask.shape[0]
    n_orig_cols = mask.shape[1]

    # Flat positions are column-major: position // n_rows is the column, position % n_rows the row
    flat_idxs = np.flatnonzero(mask.transpose())
    start_idxs = np.full(n_orig_cols, -1, dtype=int_)
    end_idxs = np.full(n_orig_cols, 0, dtype=int_)
    prev_col = -1
    for i in range(len(flat_idxs)):
        col = flat_idxs[i] // n_rows
        if col != prev_col:
            start_idxs[col] = i
        end_idxs[col] = i + 1
        prev_col = col

    # Upper bound: one slot per True value plus one slot per empty column
    n_cols = (end_idxs - start_idxs).sum()
    new_mask = np.full((n_rows, n_cols), False, dtype=np.bool_)
    source_range = np.full(n_cols, -1, dtype=int_)
    target_range = np.full(n_cols, -1, dtype=int_)
    source_idxs = np.full(n_cols, -1, dtype=int_)
    target_idxs = np.full(n_cols, -1, dtype=int_)
    col_idxs = np.empty(n_cols, dtype=int_)

    k = 0
    for i in range(len(start_idxs)):
        start_idx = start_idxs[i]
        end_idx = end_idxs[i]
        col_filled = False
        if start_idx != -1:
            for j in range(start_idx, end_idx):
                has_target = j < end_idx - 1
                if has_target or incl_open_source:
                    source_flat = flat_idxs[j]
                    new_mask[source_flat % n_rows, k] = True
                    source_range[k] = j - start_idx
                    source_idxs[k] = source_flat % n_rows
                    if has_target:
                        target_flat = flat_idxs[j + 1]
                        new_mask[target_flat % n_rows, k] = True
                        target_range[k] = j + 1 - start_idx
                        target_idxs[k] = target_flat % n_rows
                    col_idxs[k] = source_flat // n_rows
                    k += 1
                    col_filled = True
        if not col_filled and incl_empty_cols:
            # Placeholder column: continue numbering from the previously emitted column
            if k == 0:
                col_idxs[k] = 0
            else:
                col_idxs[k] = col_idxs[k - 1] + 1
            k += 1

    return (
        new_mask[:, :k],
        source_range[:k],
        target_range[:k],
        source_idxs[:k],
        target_idxs[:k],
        col_idxs[:k],
    )
False, incl_empty_cols: bool = True, ) -> tp.Tuple[ tp.Array2d, tp.Array2d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, ]: """Unravel each pair of successive True values between a source and target mask to a separate column. Index pairs are resolved with `relation_idxs_1d_nb`. Returns the new source mask, the new target mask, the index of each source True value in its column, the index of each target True value in its column, the row index of each True value in each original mask, and the column index of each True value in both original masks.""" if relation == SignalRelation.Chain or relation == SignalRelation.AnyChain: max_signals = source_mask.shape[0] * 2 else: max_signals = source_mask.shape[0] source_range_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_) target_range_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_) source_idxs_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_) target_idxs_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_) counts = np.empty(source_mask.shape[1], dtype=int_) n_cols = 0 for col in range(source_mask.shape[1]): source_range_col, target_range_col, source_idxs_col, target_idxs_col = relation_idxs_1d_nb( source_mask[:, col], target_mask[:, col], relation=relation, ) n_idxs = len(source_idxs_col) source_range_2d[:n_idxs, col] = source_range_col target_range_2d[:n_idxs, col] = target_range_col source_idxs_2d[:n_idxs, col] = source_idxs_col target_idxs_2d[:n_idxs, col] = target_idxs_col counts[col] = n_idxs if n_idxs == 0: n_cols += 1 else: n_cols += n_idxs new_source_mask = np.full((source_mask.shape[0], n_cols), False, dtype=np.bool_) new_target_mask = np.full((source_mask.shape[0], n_cols), False, dtype=np.bool_) source_range = np.full(n_cols, -1, dtype=int_) target_range = np.full(n_cols, -1, dtype=int_) source_idxs = np.full(n_cols, -1, dtype=int_) target_idxs = np.full(n_cols, -1, dtype=int_) col_idxs = np.empty(n_cols, dtype=int_) k = 0 for c in 
range(len(counts)): col_filled = False if counts[c] > 0: for j in range(counts[c]): source_idx = source_idxs_2d[j, c] target_idx = target_idxs_2d[j, c] if source_idx != -1 and target_idx != -1: new_source_mask[source_idx, k] = True new_target_mask[target_idx, k] = True source_range[k] = source_range_2d[j, c] target_range[k] = target_range_2d[j, c] source_idxs[k] = source_idx target_idxs[k] = target_idx col_idxs[k] = c k += 1 col_filled = True elif source_idx != -1 and incl_open_source: new_source_mask[source_idx, k] = True source_range[k] = source_range_2d[j, c] source_idxs[k] = source_idx col_idxs[k] = c k += 1 col_filled = True elif target_idx != -1 and incl_open_target: new_target_mask[target_idx, k] = True target_range[k] = target_range_2d[j, c] target_idxs[k] = target_idx col_idxs[k] = c k += 1 col_filled = True if not col_filled and incl_empty_cols: col_idxs[k] = c k += 1 return ( new_source_mask[:, :k], new_target_mask[:, :k], source_range[:k], target_range[:k], source_idxs[:k], target_idxs[:k], col_idxs[:k], ) @register_jitted(cache=True, tags={"can_parallel"}) def ravel_nb(mask: tp.Array2d, group_map: tp.GroupMap) -> tp.Array2d: """Ravel True values of each group into a separate column.""" group_idxs, group_lens = group_map group_start_idxs = np.cumsum(group_lens) - group_lens out = np.full((mask.shape[0], len(group_lens)), False, dtype=np.bool_) for group in prange(len(group_lens)): group_len = group_lens[group] start_idx = group_start_idxs[group] col_idxs = group_idxs[start_idx : start_idx + group_len] for col in col_idxs: for i in range(mask.shape[0]): if mask[i, col]: out[i, group] = True return out # ############# Index ############# # @register_jitted(cache=True) def nth_index_1d_nb(mask: tp.Array1d, n: int) -> int: """Get the index of the n-th True value. !!! 
@register_chunkable(
    size=ch.ArraySizer(arg_query="group_lens", axis=0),
    arg_take_spec=dict(
        mask=base_ch.array_gl_slicer,
        group_lens=ch.ArraySlicer(axis=0),
    ),
    merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def norm_avg_index_grouped_nb(mask: tp.Array2d, group_lens: tp.Array1d) -> tp.Array1d:
    """Grouped version of `norm_avg_index_nb`.

    For each group, averages the row indices of all True values across the group's
    columns and rescales the average from `(0, mask.shape[0] - 1)` to `(-1, 1)`
    with `rescale_nb`.

    A group without any True value (including a group of zero columns) yields NaN
    instead of dividing by zero."""
    out = np.empty(len(group_lens), dtype=float_)
    group_end_idxs = np.cumsum(group_lens)
    group_start_idxs = group_end_idxs - group_lens

    for group in prange(len(group_lens)):
        from_col = group_start_idxs[group]
        to_col = group_end_idxs[group]
        temp_sum = 0
        temp_cnt = 0
        for col in range(from_col, to_col):
            for i in range(mask.shape[0]):
                if mask[i, col]:
                    temp_sum += i
                    temp_cnt += 1
        if temp_cnt == 0:
            # No True values: the mean index is undefined
            out[group] = np.nan
        else:
            out[group] = rescale_nb(temp_sum / temp_cnt, (0, mask.shape[0] - 1), (-1, 1))
    return out
"contourcarpet": [ { "type": "contourcarpet", "colorbar": { "outlinewidth": 0, "ticks": "" } } ], "contour": [ { "type": "contour", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "surface": [ { "type": "surface", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "mesh3d": [ { "type": "mesh3d", "colorbar": { "outlinewidth": 0, "ticks": "" } } ], "scatter": [ { "marker": { "line": { "color": "#313439" } }, "type": "scatter" } ], "parcoords": [ { "type": "parcoords", "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterpolargl": [ { "type": "scatterpolargl", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "bar": [ { "error_x": { "color": "#d6dfef" }, "error_y": { "color": "#d6dfef" }, "marker": { "line": { "color": "#232428", "width": 0.5 } }, "type": "bar" } ], "scattergeo": [ { "type": "scattergeo", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterpolar": [ { "type": "scatterpolar", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "histogram": [ { "type": "histogram", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattergl": [ { "marker": { "line": { "color": "#313439" } }, "type": "scattergl" } ], "scatter3d": [ { "type": "scatter3d", "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, 
"marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattermapbox": [ { "type": "scattermapbox", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterternary": [ { "type": "scatterternary", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattercarpet": [ { "type": "scattercarpet", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "carpet": [ { "aaxis": { "endlinecolor": "#A2B1C6", "gridcolor": "#313439", "linecolor": "#313439", "minorgridcolor": "#313439", "startlinecolor": "#A2B1C6" }, "baxis": { "endlinecolor": "#A2B1C6", "gridcolor": "#313439", "linecolor": "#313439", "minorgridcolor": "#313439", "startlinecolor": "#A2B1C6" }, "type": "carpet" } ], "table": [ { "cells": { "fill": { "color": "#313439" }, "line": { "color": "#232428" } }, "header": { "fill": { "color": "#2a3f5f" }, "line": { "color": "#232428" } }, "type": "table" } ], "barpolar": [ { "marker": { "line": { "color": "#232428", "width": 0.5 } }, "type": "barpolar" } ], "pie": [ { "automargin": true, "type": "pie" } ] }, "layout": { "colorway": [ "#1f77b4", "#ff7f0e", "#2ca02c", "#dc3912", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf" ], "font": { "color": "#d6dfef" }, "hovermode": "closest", "hoverlabel": { "align": "left" }, "paper_bgcolor": "#232428", "plot_bgcolor": "#232428", "polar": { "bgcolor": "#232428", "angularaxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "" }, "radialaxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "" } }, "ternary": { "bgcolor": "#232428", "aaxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "" }, "baxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "" }, "caxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "" } }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "sequential": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 
0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "sequentialminus": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ] }, "xaxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "#313439", "automargin": true, "zerolinewidth": 2 }, "yaxis": { "gridcolor": "#313439", "linecolor": "#313439", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "#313439", "automargin": true, "zerolinewidth": 2 }, "scene": { "xaxis": { "backgroundcolor": "#232428", "gridcolor": "#313439", "linecolor": "#313439", "showbackground": true, "ticks": "", "zerolinecolor": "#aec0d6", "gridwidth": 2 }, "yaxis": { "backgroundcolor": "#232428", "gridcolor": "#313439", "linecolor": "#313439", "showbackground": true, "ticks": "", "zerolinecolor": "#aec0d6", "gridwidth": 2 }, "zaxis": { "backgroundcolor": "#232428", "gridcolor": "#313439", "linecolor": "#313439", "showbackground": true, "ticks": "", "zerolinecolor": "#aec0d6", "gridwidth": 2 } }, "shapedefaults": { "line": { "color": "#d6dfef" } }, "annotationdefaults": { "arrowcolor": "#d6dfef", "arrowhead": 0, "arrowwidth": 1 }, "geo": { "bgcolor": "#232428", "landcolor": "#232428", "subunitcolor": "#313439", "showland": true, "showlakes": true, "lakecolor": "#232428" }, "title": { "x": 0.05 }, 
"updatemenudefaults": { "bgcolor": "#313439", "borderwidth": 0 }, "sliderdefaults": { "bgcolor": "#aec0d6", "borderwidth": 1, "bordercolor": "#232428", "tickwidth": 0 }, "mapbox": { "style": "dark" } } } { "data": { "histogram2dcontour": [ { "type": "histogram2dcontour", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "choropleth": [ { "type": "choropleth", "colorbar": { "outlinewidth": 0, "ticks": "" } } ], "histogram2d": [ { "type": "histogram2d", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "heatmap": [ { "type": "heatmap", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "heatmapgl": [ { "type": "heatmapgl", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, 
"#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "contourcarpet": [ { "type": "contourcarpet", "colorbar": { "outlinewidth": 0, "ticks": "" } } ], "contour": [ { "type": "contour", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "surface": [ { "type": "surface", "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ] } ], "mesh3d": [ { "type": "mesh3d", "colorbar": { "outlinewidth": 0, "ticks": "" } } ], "scatter": [ { "type": "scatter", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "parcoords": [ { "type": "parcoords", "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterpolargl": [ { "type": "scatterpolargl", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "bar" } ], "scattergeo": [ { "type": "scattergeo", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterpolar": [ { "type": "scatterpolar", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "histogram": [ { "type": "histogram", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattergl": [ { "type": "scattergl", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatter3d": [ { "type": 
"scatter3d", "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattermapbox": [ { "type": "scattermapbox", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scatterternary": [ { "type": "scatterternary", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "scattercarpet": [ { "type": "scattercarpet", "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } } } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 } }, "type": "barpolar" } ], "pie": [ { "automargin": true, "type": "pie" } ] }, "layout": { "colorway": [ "#1f77b4", "#ff7f0e", "#2ca02c", "#dc3912", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf" ], "font": { "color": "#2a3f5f" }, "hovermode": "closest", "hoverlabel": { "align": "left" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "bgcolor": "#E5ECF6", "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "ternary": { "bgcolor": "#E5ECF6", "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "sequential": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 
0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "sequentialminus": [ [ 0.0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1.0, "#f0f921" ] ], "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ] }, "xaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "automargin": true, "zerolinewidth": 2 }, "yaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "automargin": true, "zerolinewidth": 2 }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "geo": { "bgcolor": "white", "landcolor": "#E5ECF6", "subunitcolor": "white", "showland": true, "showlakes": true, "lakecolor": "white" }, "title": { "x": 0.05 }, "mapbox": { 
"style": "light" } } } { "layout": { "colorway": [ "rgb(76,114,176)", "rgb(221,132,82)", "rgb(85,168,104)", "rgb(196,78,82)", "rgb(129,114,179)", "rgb(147,120,96)", "rgb(218,139,195)", "rgb(140,140,140)", "rgb(204,185,116)", "rgb(100,181,205)" ], "font": { "color": "rgb(36,36,36)" }, "hovermode": "closest", "hoverlabel": { "align": "left" }, "paper_bgcolor": "white", "plot_bgcolor": "rgb(234,234,242)", "polar": { "bgcolor": "rgb(234,234,242)", "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "ternary": { "bgcolor": "rgb(234,234,242)", "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "sequential": [ [ 0.0, "rgb(2,4,25)" ], [ 0.06274509803921569, "rgb(24,15,41)" ], [ 0.12549019607843137, "rgb(47,23,57)" ], [ 0.18823529411764706, "rgb(71,28,72)" ], [ 0.25098039215686274, "rgb(97,30,82)" ], [ 0.3137254901960784, "rgb(123,30,89)" ], [ 0.3764705882352941, "rgb(150,27,91)" ], [ 0.4392156862745098, "rgb(177,22,88)" ], [ 0.5019607843137255, "rgb(203,26,79)" ], [ 0.5647058823529412, "rgb(223,47,67)" ], [ 0.6274509803921569, "rgb(236,76,61)" ], [ 0.6901960784313725, "rgb(242,107,73)" ], [ 0.7529411764705882, "rgb(244,135,95)" ], [ 0.8156862745098039, "rgb(245,162,122)" ], [ 0.8784313725490196, "rgb(246,188,153)" ], [ 0.9411764705882353, "rgb(247,212,187)" ], [ 1.0, "rgb(250,234,220)" ] ], "sequentialminus": [ [ 0.0, "rgb(2,4,25)" ], [ 0.06274509803921569, "rgb(24,15,41)" ], [ 0.12549019607843137, "rgb(47,23,57)" ], [ 0.18823529411764706, "rgb(71,28,72)" ], [ 0.25098039215686274, "rgb(97,30,82)" ], [ 0.3137254901960784, "rgb(123,30,89)" ], [ 0.3764705882352941, "rgb(150,27,91)" ], [ 0.4392156862745098, "rgb(177,22,88)" ], [ 
0.5019607843137255, "rgb(203,26,79)" ], [ 0.5647058823529412, "rgb(223,47,67)" ], [ 0.6274509803921569, "rgb(236,76,61)" ], [ 0.6901960784313725, "rgb(242,107,73)" ], [ 0.7529411764705882, "rgb(244,135,95)" ], [ 0.8156862745098039, "rgb(245,162,122)" ], [ 0.8784313725490196, "rgb(246,188,153)" ], [ 0.9411764705882353, "rgb(247,212,187)" ], [ 1.0, "rgb(250,234,220)" ] ] }, "xaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "automargin": true, "zerolinewidth": 2 }, "yaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "automargin": true, "zerolinewidth": 2 }, "scene": { "xaxis": { "backgroundcolor": "rgb(234,234,242)", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 }, "yaxis": { "backgroundcolor": "rgb(234,234,242)", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 }, "zaxis": { "backgroundcolor": "rgb(234,234,242)", "gridcolor": "white", "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white", "gridwidth": 2 } }, "shapedefaults": { "line": { "color": "rgb(67,103,167)" } }, "annotationdefaults": { "arrowcolor": "rgb(67,103,167)", "arrowhead": 0, "arrowwidth": 1 }, "geo": { "bgcolor": "white", "landcolor": "rgb(234,234,242)", "subunitcolor": "white", "showland": true, "showlakes": true, "lakecolor": "white" }, "title": { "x": 0.05 }, "mapbox": { "style": "light" } } } # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
# # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Utilities for working with knowledge. Run for the examples: ```pycon >>> dataset = [ ... {"s": "ABC", "b": True, "d2": {"c": "red", "l": [1, 2]}}, ... {"s": "BCD", "b": True, "d2": {"c": "blue", "l": [3, 4]}}, ... {"s": "CDE", "b": False, "d2": {"c": "green", "l": [5, 6]}}, ... {"s": "DEF", "b": False, "d2": {"c": "yellow", "l": [7, 8]}}, ... {"s": "EFG", "b": False, "d2": {"c": "black", "l": [9, 10]}, "xyz": 123} ... ] >>> asset = vbt.KnowledgeAsset(dataset) ``` """ from typing import TYPE_CHECKING if TYPE_CHECKING: from vectorbtpro.utils.knowledge.asset_pipelines import * from vectorbtpro.utils.knowledge.base_asset_funcs import * from vectorbtpro.utils.knowledge.base_assets import * from vectorbtpro.utils.knowledge.chatting import * from vectorbtpro.utils.knowledge.custom_asset_funcs import * from vectorbtpro.utils.knowledge.custom_assets import * from vectorbtpro.utils.knowledge.formatting import * # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Asset pipeline classes. 
class AssetPipeline(Base):
    """Abstract base for asset pipelines.

    Subclasses implement `AssetPipeline.run`; calling an instance delegates to it."""

    @classmethod
    def resolve_task(
        cls,
        func: tp.AssetFuncLike,
        *args,
        prepare: bool = True,
        prepare_once: bool = True,
        cond_kwargs: tp.KwargsLike = None,
        asset_func_meta: tp.Union[None, dict, list] = None,
        **kwargs,
    ) -> tp.Task:
        """Turn a function-like object into a task.

        `func` can be a tuple (converted with `Task.from_tuple`), a task (its arguments are
        combined with `args` and `kwargs`), a string, a subclass of `AssetFunc`, or any callable.

        A string is looked up among the attributes of `base_asset_funcs` and `custom_asset_funcs`:
        first by exact name, then by `func.title() + "AssetFunc"`, and finally by comparing
        it case-insensitively against `_short_name` of each `AssetFunc` subclass.

        For a subclass of `AssetFunc`:

        * Private class variables annotated with `ClassVar` are collected and, if `asset_func_meta`
            is provided, merged into it (dict) or appended to it (list).
        * If `prepare` and `prepare_once` are True, `prepare` is called right away
            (with the entries of `cond_kwargs` that it accepts) and `call` becomes the task function.
        * If only `prepare` is True, `prepare_and_call` becomes the task function.
        * Otherwise, `call` becomes the task function.

        Raises `ValueError` if a string cannot be resolved and `TypeError` if `func`
        is an instance of `AssetFunc` or the final function is not callable."""
        if isinstance(func, tuple):
            func = Task.from_tuple(func)
        if isinstance(func, Task):
            args = (*func.args, *args)
            kwargs = merge_dicts(func.kwargs, kwargs)
            func = func.func

        if isinstance(func, str):
            from vectorbtpro.utils.knowledge import base_asset_funcs, custom_asset_funcs

            # Attributes of the custom module replace same-named attributes of the base module
            namespace = {}
            for module in (base_asset_funcs, custom_asset_funcs):
                for attr_name in dir(module):
                    namespace[attr_name] = getattr(module, attr_name)

            exact_match = namespace.get(func)
            titled_name = func.title() + "AssetFunc"
            if isinstance(exact_match, type) and issubclass(exact_match, AssetFunc):
                func = exact_match
            elif titled_name in namespace:
                func = namespace[titled_name]
            else:
                wanted = func.lower()
                for candidate in namespace.values():
                    if not (isinstance(candidate, type) and issubclass(candidate, AssetFunc)):
                        continue
                    short_name = candidate._short_name
                    if short_name is not None and wanted == short_name.lower():
                        func = candidate
                        break
                else:
                    raise ValueError(f"Function '{func}' not found")

        if isinstance(func, AssetFunc):
            raise TypeError("Function must be a subclass of AssetFunc, not an instance")

        if isinstance(func, type) and issubclass(func, AssetFunc):
            _asset_func_meta = {
                var_name: getattr(func, var_name)
                for var_name, var_type in func.__annotations__.items()
                if var_name.startswith("_") and tp.get_origin(var_type) is tp.ClassVar
            }
            if asset_func_meta is not None:
                if isinstance(asset_func_meta, dict):
                    asset_func_meta.update(_asset_func_meta)
                else:
                    asset_func_meta.append(_asset_func_meta)

            if not prepare:
                func = func.call
            elif not prepare_once:
                func = func.prepare_and_call
            else:
                if cond_kwargs is not None and len(cond_kwargs) > 0:
                    accepted_names = get_func_arg_names(func.prepare)
                    kwargs.update({k: v for k, v in cond_kwargs.items() if k in accepted_names})
                args, kwargs = func.prepare(*args, **kwargs)
                func = func.call

        if not callable(func):
            raise TypeError("Function must be callable")
        return Task(func, *args, **kwargs)

    def run(self, d: tp.Any) -> tp.Any:
        """Run the pipeline on a data item.

        Must be implemented by subclasses."""
        raise NotImplementedError

    def __call__(self, d: tp.Any) -> tp.Any:
        return self.run(d)
Usage: ```pycon >>> asset_pipeline = vbt.BasicAssetPipeline() >>> asset_pipeline.append("flatten") >>> asset_pipeline.append("query", len) >>> asset_pipeline.append("get") >>> asset_pipeline(dataset[0]) 5 ``` """ def __init__(self, *args, **kwargs) -> None: if len(args) == 0: tasks = [] else: tasks = args[0] args = args[1:] if not isinstance(tasks, list): tasks = [tasks] self._tasks = [self.resolve_task(task, *args, **kwargs) for task in tasks] @property def tasks(self) -> tp.List[tp.Task]: """Tasks.""" return self._tasks def append(self, func: tp.AssetFuncLike, *args, **kwargs) -> None: """Append a task to the pipeline.""" self.tasks.append(self.resolve_task(func, *args, **kwargs)) @classmethod def compose_tasks(cls, tasks: tp.List[tp.Task]) -> tp.Callable: """Compose multiple tasks into one.""" def composed(d): result = d for func, args, kwargs in tasks: result = func(result, *args, **kwargs) return result return composed def run(self, d: tp.Any) -> tp.Any: return self.compose_tasks(list(self.tasks))(d) class ComplexAssetPipeline(AssetPipeline): """Class representing a complex asset pipeline. Takes an expression string and a context. Resolves functions inside the expression. Expression is evaluated with `vectorbtpro.utils.eval_.evaluate`. Usage: ```pycon >>> asset_pipeline = vbt.ComplexAssetPipeline("query(flatten(d), len)") >>> asset_pipeline(dataset[0]) 5 ``` """ @classmethod def resolve_expression_and_context( cls, expression: str, context: tp.KwargsLike = None, prepare: bool = True, prepare_once: bool = True, **resolve_task_kwargs, ) -> tp.Tuple[str, tp.Kwargs]: """Resolve an expression and a context. 
Parses an expression string, extracts function calls with their arguments, removing the first positional argument from each function, and creates a new context.""" import importlib import builtins import ast import sys if context is None: context = {} for k, v in package_shortcut_config.items(): if k not in context: try: context[k] = importlib.import_module(v) except ImportError: pass tree = ast.parse(expression) builtin_functions = set(dir(builtins)) imported_functions = set() imported_modules = set() defined_functions = set() func_context = {} class _FunctionAnalyzer(ast.NodeVisitor): def visit_Import(self, node): for alias in node.names: name = alias.asname if alias.asname else alias.name.split(".")[0] imported_modules.add(name) self.generic_visit(node) def visit_ImportFrom(self, node): for alias in node.names: name = alias.asname if alias.asname else alias.name imported_functions.add(name) self.generic_visit(node) def visit_FunctionDef(self, node): defined_functions.add(node.name) self.generic_visit(node) analyzer = _FunctionAnalyzer() analyzer.visit(tree) class _NodeMixin: def get_func_name(self, func): attrs = [] while isinstance(func, ast.Attribute): attrs.append(func.attr) func = func.value if isinstance(func, ast.Name): attrs.append(func.id) return ".".join(reversed(attrs)) if attrs else "" def is_function_assigned(self, func): func_name = self.get_func_name(func) if "." 
in func_name: func_name = func_name.split(".")[0] return ( func_name in context or func_name in builtin_functions or func_name in imported_functions or func_name in imported_modules or func_name in defined_functions ) class _FunctionCallVisitor(ast.NodeVisitor, _NodeMixin): def process_argument(self, arg): if isinstance(arg, ast.Constant): return arg.value elif isinstance(arg, ast.Name): var_name = arg.id if var_name in context: return context[var_name] elif var_name in builtin_functions: return getattr(builtins, var_name) else: raise ValueError(f"Variable '{var_name}' is not defined in the context") elif isinstance(arg, ast.List): return [self.process_argument(elem) for elem in arg.elts] elif isinstance(arg, ast.Tuple): return tuple(self.process_argument(elem) for elem in arg.elts) elif isinstance(arg, ast.Dict): return {self.process_argument(k): self.process_argument(v) for k, v in zip(arg.keys, arg.values)} elif isinstance(arg, ast.Set): return {self.process_argument(elem) for elem in arg.elts} elif isinstance(arg, ast.Call): if self.is_function_assigned(arg.func): return self.get_func_name(arg.func) raise ValueError(f"Unsupported or dynamic argument: {ast.dump(arg)}") def visit_Call(self, node): self.generic_visit(node) func_name = self.get_func_name(node.func) pos_args = [] for arg in node.args[1:]: arg_value = self.process_argument(arg) pos_args.append(arg_value) kw_args = {} for kw in node.keywords: if kw.arg is None: raise ValueError(f"Dynamic keyword argument names are not allowed in '{func_name}'") kw_name = kw.arg kw_value = self.process_argument(kw.value) kw_args[kw_name] = kw_value if not self.is_function_assigned(node.func): task = cls.resolve_task( func_name, *pos_args, **kw_args, prepare=prepare, prepare_once=prepare_once, **resolve_task_kwargs, ) if prepare and prepare_once: def func(d, _task=task): return _task.func(d, *_task.args, **_task.kwargs) else: func = task.func func_context[func_name] = func visitor = _FunctionCallVisitor() 
def __init__(
    self,
    expression: str,
    context: tp.KwargsLike = None,
    prepare_once: bool = True,
    **resolve_task_kwargs,
) -> None:
    # Resolve the expression and its evaluation context once at construction
    # time; both are stored and later consumed by `run`.
    resolved = self.resolve_expression_and_context(
        expression,
        context=context,
        prepare_once=prepare_once,
        **resolve_task_kwargs,
    )
    resolved_expression, resolved_context = resolved
    self._expression = resolved_expression
    self._context = resolved_context

@property
def expression(self) -> str:
    """Expression returned by `resolve_expression_and_context` at construction."""
    return self._expression

@property
def context(self) -> tp.Kwargs:
    """Context returned by `resolve_expression_and_context` at construction."""
    return self._context

def run(self, d: tp.Any) -> tp.Any:
    """Evaluate the stored expression on a data item.

    The item is exposed under both names `d` and `x`; this base mapping is then
    combined with the stored context via `merge_dicts` before evaluation."""
    item_context = {"d": d, "x": d}
    eval_context = merge_dicts(item_context, self.context)
    return evaluate(self.expression, context=eval_context)
class AssetFunc(Base):
    """Abstract class representing an asset function.

    An asset function is split into two class-level steps: `AssetFunc.prepare`, which
    normalizes the user-supplied arguments, and `AssetFunc.call`, which processes
    a single data item using the prepared arguments. `AssetFunc.prepare_and_call`
    chains both."""

    _short_name: tp.ClassVar[tp.Optional[str]] = None
    """Short name of the function to be used in expressions."""

    # Annotated as a boolean flag: subclasses assign True/False, not strings
    _wrap: tp.ClassVar[tp.Optional[bool]] = None
    """Whether the results are meant to be wrapped with `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset`."""

    @classmethod
    def prepare(cls, *args, **kwargs) -> tp.ArgsKwargs:
        """Prepare positional and keyword arguments.

        The base implementation returns the arguments unchanged as a tuple
        `(args, kwargs)`."""
        return args, kwargs

    @classmethod
    def call(cls, d: tp.Any, *args, **kwargs) -> tp.Any:
        """Call the function on the data item `d`.

        Must be overridden by subclasses; the base implementation
        raises `NotImplementedError`."""
        raise NotImplementedError

    @classmethod
    def prepare_and_call(cls, d: tp.Any, *args, **kwargs) -> tp.Any:
        """Prepare arguments with `AssetFunc.prepare` and pass them
        to `AssetFunc.call` together with the data item `d`."""
        args, kwargs = cls.prepare(*args, **kwargs)
        return cls.call(d, *args, **kwargs)
tp.Optional[tp.Type[tp.KnowledgeAsset]] = None, **kwargs, ) -> tp.ArgsKwargs: if asset_cls is None: from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset asset_cls = KnowledgeAsset keep_path = asset_cls.resolve_setting(keep_path, "keep_path") skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing") template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True) template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context) if path is not None: if isinstance(path, list): path = [search_.resolve_pathlike_key(p) for p in path] else: path = search_.resolve_pathlike_key(path) if source is not None: if isinstance(source, str): source = RepEval(source) elif checks.is_function(source): if checks.is_builtin_func(source): source = RepFunc(lambda _source=source: _source) else: source = RepFunc(source) elif not isinstance(source, CustomTemplate): raise TypeError(f"Source must be a string, function, or template, not {type(source)}") return (), { **dict( path=path, keep_path=keep_path, skip_missing=skip_missing, source=source, template_context=template_context, ), **kwargs, } @classmethod def call( cls, d: tp.Any, path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None, keep_path: bool = False, skip_missing: bool = False, source: tp.Optional[tp.CustomTemplate] = None, template_context: tp.KwargsLike = None, ) -> tp.Any: x = d if path is not None: if isinstance(path, list): xs = [] for p in path: try: xs.append(search_.get_pathlike_key(x, p, keep_path=True)) except (KeyError, IndexError, AttributeError) as e: if not skip_missing: raise e continue if len(xs) == 0: return NoResult x = merge_dicts(*xs) else: try: x = search_.get_pathlike_key(x, path, keep_path=keep_path) except (KeyError, IndexError, AttributeError) as e: if not skip_missing: raise e return NoResult if source is not None: _template_context = flat_merge_dicts( { "d": d, "x": x, **(x if isinstance(x, dict) else {}), }, template_context, ) new_d 
class SetAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.set`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "set"

    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def prepare(
        cls,
        value: tp.Any,
        path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        skip_missing: tp.Optional[bool] = None,
        make_copy: tp.Optional[bool] = None,
        changed_only: tp.Optional[bool] = None,
        template_context: tp.KwargsLike = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve settings and normalize `value` and `path` for `SetAssetFunc.call`.

        Settings passed as None are resolved through `asset_cls.resolve_setting`
        (`asset_cls` defaults to `KnowledgeAsset`). A function value is wrapped with
        `RepFunc`; any other value is passed through unchanged. `path` is resolved into
        a list of paths, `[None]` meaning that the whole data item is replaced.

        Returns an empty tuple of positional arguments and the keyword arguments
        for `SetAssetFunc.call`; extra `kwargs` are forwarded as they are."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

            asset_cls = KnowledgeAsset
        skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
        make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
        changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
        template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
        template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)

        if checks.is_function(value):
            if checks.is_builtin_func(value):
                # The template yields the builtin itself; `call` then invokes it
                # with the parent object
                value = RepFunc(lambda _value=value: _value)
            else:
                value = RepFunc(value)
        if path is not None:
            if isinstance(path, list):
                paths = [search_.resolve_pathlike_key(p) for p in path]
            else:
                paths = [search_.resolve_pathlike_key(path)]
        else:
            paths = [None]
        return (), {
            **dict(
                value=value,
                paths=paths,
                skip_missing=skip_missing,
                make_copy=make_copy,
                changed_only=changed_only,
                template_context=template_context,
            ),
            **kwargs,
        }

    @classmethod
    def call(
        cls,
        d: tp.Any,
        value: tp.Any,
        paths: tp.List[tp.PathLikeKey],
        skip_missing: bool = False,
        make_copy: bool = True,
        changed_only: bool = False,
        template_context: tp.KwargsLike = None,
    ) -> tp.Any:
        """Set `value` under each path in `paths` of the data item `d`.

        For each path, the parent object (the path without its last token, or `d` itself
        if the path is None) is looked up. If `value` is a template, it is substituted with
        a context containing `d`, the parent as `x`, and the parent's items if it is a dict;
        a function resulting from the substitution is called with the parent. Any other
        value is set as it is.

        If the parent cannot be resolved, the error is re-raised unless `skip_missing`
        is True, in which case the path is skipped.

        Returns the updated data item, or `NoResult` if `changed_only` is True and
        `prev_keys` was left empty by `search_.set_pathlike_key`."""
        prev_keys = []
        for p in paths:
            x = d
            if p is not None:
                try:
                    x = search_.get_pathlike_key(x, p[:-1])
                except (KeyError, IndexError, AttributeError):
                    if not skip_missing:
                        raise
                    continue
            if isinstance(value, CustomTemplate):
                _template_context = flat_merge_dicts(
                    {
                        "d": d,
                        "x": x,
                        **(x if isinstance(x, dict) else {}),
                    },
                    template_context,
                )
                v = value.substitute(_template_context, eval_id="value")
                if checks.is_function(v):
                    v = v(x)
            else:
                # Plain values (numbers, strings, containers, ...) have no `substitute`
                # method and are set unchanged
                v = value
            d = search_.set_pathlike_key(d, p, v, make_copy=make_copy, prev_keys=prev_keys)
        if not changed_only or len(prev_keys) > 0:
            return d
        return NoResult
MoveAssetFunc(AssetFunc): """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.move`.""" _short_name: tp.ClassVar[tp.Optional[str]] = "move" _wrap: tp.ClassVar[tp.Optional[str]] = True @classmethod def prepare( cls, path: tp.Union[tp.PathMoveDict, tp.MaybeList[tp.PathLikeKey]], new_path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None, skip_missing: tp.Optional[bool] = None, make_copy: tp.Optional[bool] = None, changed_only: tp.Optional[bool] = None, asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None, **kwargs, ) -> tp.ArgsKwargs: if asset_cls is None: from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset asset_cls = KnowledgeAsset skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing") make_copy = asset_cls.resolve_setting(make_copy, "make_copy") changed_only = asset_cls.resolve_setting(changed_only, "changed_only") if new_path is None: checks.assert_instance_of(path, dict, arg_name="path") new_path = list(path.values()) path = list(path.keys()) if isinstance(path, list): paths = [search_.resolve_pathlike_key(p) for p in path] else: paths = [search_.resolve_pathlike_key(path)] if isinstance(new_path, list): new_paths = [search_.resolve_pathlike_key(p) for p in new_path] else: new_paths = [search_.resolve_pathlike_key(new_path)] if len(paths) != len(new_paths): raise ValueError("Number of new paths must match number of paths") return (), { **dict( paths=paths, new_paths=new_paths, skip_missing=skip_missing, make_copy=make_copy, changed_only=changed_only, ), **kwargs, } @classmethod def call( cls, d: tp.Any, paths: tp.List[tp.PathLikeKey], new_paths: tp.List[tp.PathLikeKey], skip_missing: bool = False, make_copy: bool = True, changed_only: bool = False, ) -> tp.Any: prev_keys = [] for i, p in enumerate(paths): try: x = search_.get_pathlike_key(d, p) d = search_.remove_pathlike_key(d, p, make_copy=make_copy, prev_keys=prev_keys) d = search_.set_pathlike_key(d, new_paths[i], x, make_copy=make_copy, 
class ReorderAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.reorder`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "reorder"

    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def prepare(
        cls,
        new_order: tp.Union[str, tp.PathKeyTokens],
        path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        skip_missing: tp.Optional[bool] = None,
        make_copy: tp.Optional[bool] = None,
        changed_only: tp.Optional[bool] = None,
        template_context: tp.KwargsLike = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve settings and normalize `new_order` and `path` for `ReorderAssetFunc.call`.

        `new_order` can be:

        * "asc"/"ascending" or "desc"/"descending" (case-insensitive): replaced by a function
            that returns the sorted keys of a dict, or the indices of a sequence sorted by
            their values, in the respective direction
        * Any other string: wrapped with `RepEval`
        * Function: wrapped with `RepFunc`
        * Anything else: passed through unchanged

        Settings passed as None are resolved through `asset_cls.resolve_setting`
        (`asset_cls` defaults to `KnowledgeAsset`). `path` is resolved into a list of paths,
        `[None]` meaning that the data item itself is reordered.

        Returns an empty tuple of positional arguments and the keyword arguments
        for `ReorderAssetFunc.call`; extra `kwargs` are forwarded as they are."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

            asset_cls = KnowledgeAsset
        skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
        make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
        changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
        template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
        template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)

        if isinstance(new_order, str):
            order_name = new_order.lower()
            if order_name in ("asc", "ascending"):
                new_order = lambda x: (
                    sorted(x)
                    if isinstance(x, dict)
                    else sorted(
                        range(len(x)),
                        key=x.__getitem__,
                    )
                )
            elif order_name in ("desc", "descending"):
                # Both branches must sort in reverse; the dict branch sorts keys
                new_order = lambda x: (
                    sorted(x, reverse=True)
                    if isinstance(x, dict)
                    else sorted(
                        range(len(x)),
                        key=x.__getitem__,
                        reverse=True,
                    )
                )
        # Checked again since a shortcut string has been replaced by a function above
        if isinstance(new_order, str):
            new_order = RepEval(new_order)
        elif checks.is_function(new_order):
            if checks.is_builtin_func(new_order):
                # The template yields the builtin itself; `call` then invokes it
                # with the object being reordered
                new_order = RepFunc(lambda _new_order=new_order: _new_order)
            else:
                new_order = RepFunc(new_order)
        if path is not None:
            if isinstance(path, list):
                paths = [search_.resolve_pathlike_key(p) for p in path]
            else:
                paths = [search_.resolve_pathlike_key(path)]
        else:
            paths = [None]
        return (), {
            **dict(
                new_order=new_order,
                paths=paths,
                skip_missing=skip_missing,
                make_copy=make_copy,
                changed_only=changed_only,
                template_context=template_context,
            ),
            **kwargs,
        }

    @classmethod
    def call(
        cls,
        d: tp.Any,
        new_order: tp.Union[tp.PathKeyTokens, tp.CustomTemplate],
        paths: tp.List[tp.PathLikeKey],
        skip_missing: bool = False,
        make_copy: bool = True,
        changed_only: bool = False,
        template_context: tp.KwargsLike = None,
    ) -> tp.Any:
        """Reorder the object under each path in `paths` of the data item `d`.

        For each path, the object is looked up (`d` itself if the path is None). If `new_order`
        is a template, it is substituted with a context containing `d`, the object as `x`, and
        the object's items if it is a dict; a function resulting from the substitution is called
        with the object. Dicts are reordered with `reorder_dict`, other objects with
        `reorder_list` and converted back to their original type.

        If a path cannot be resolved, the error is re-raised unless `skip_missing`
        is True, in which case the path is skipped.

        Returns the updated data item, or `NoResult` if `changed_only` is True and
        `prev_keys` was left empty by `search_.set_pathlike_key`."""
        prev_keys = []
        for p in paths:
            x = d
            if p is not None:
                try:
                    x = search_.get_pathlike_key(x, p)
                except (KeyError, IndexError, AttributeError):
                    if not skip_missing:
                        raise
                    continue
            if isinstance(new_order, CustomTemplate):
                _template_context = flat_merge_dicts(
                    {
                        "d": d,
                        "x": x,
                        **(x if isinstance(x, dict) else {}),
                    },
                    template_context,
                )
                _new_order = new_order.substitute(_template_context, eval_id="new_order")
                if checks.is_function(_new_order):
                    _new_order = _new_order(x)
            else:
                _new_order = new_order
            if isinstance(x, dict):
                x = reorder_dict(x, _new_order, skip_missing=skip_missing)
            elif checks.is_namedtuple(x):
                # Named tuples take their fields as separate positional arguments
                x = type(x)(*reorder_list(x, _new_order, skip_missing=skip_missing))
            else:
                x = type(x)(reorder_list(x, _new_order, skip_missing=skip_missing))
            d = search_.set_pathlike_key(d, p, x, make_copy=make_copy, prev_keys=prev_keys)
        if not changed_only or len(prev_keys) > 0:
            return d
        return NoResult
isinstance(expression, str): expression = RepEval(expression) elif checks.is_function(expression): if checks.is_builtin_func(expression): expression = RepFunc(lambda _expression=expression: _expression) else: expression = RepFunc(expression) elif not isinstance(expression, CustomTemplate): raise TypeError(f"Expression must be a string, function, or template, not {type(expression)}") return (), { **dict( expression=expression, template_context=template_context, return_type=return_type, ), **kwargs, } @classmethod def call( cls, d: tp.Any, expression: tp.CustomTemplate, template_context: tp.KwargsLike = None, return_type: str = "item", ) -> tp.Any: _template_context = flat_merge_dicts( { "d": d, "x": d, **search_.search_config, **(d if isinstance(d, dict) else {}), }, template_context, ) new_d = expression.substitute(_template_context, eval_id="expression") if checks.is_function(new_d): new_d = new_d(d) if return_type.lower() == "item": as_filter = True elif return_type.lower() == "bool": as_filter = False else: raise ValueError(f"Invalid return type: '{return_type}'") if as_filter and isinstance(new_d, bool): if new_d: return d return NoResult return new_d class FindAssetFunc(AssetFunc): """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.find`.""" _short_name: tp.ClassVar[tp.Optional[str]] = "find" _wrap: tp.ClassVar[tp.Optional[str]] = True @classmethod def prepare( cls, target: tp.MaybeList[tp.Any], path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None, per_path: tp.Optional[bool] = None, find_all: tp.Optional[bool] = None, keep_path: tp.Optional[bool] = None, skip_missing: tp.Optional[bool] = None, source: tp.Optional[tp.CustomTemplateLike] = None, in_dumps: tp.Optional[bool] = None, dump_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, return_type: tp.Optional[str] = None, return_path: tp.Optional[bool] = None, asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None, **kwargs, ) -> tp.ArgsKwargs: if 
asset_cls is None: from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset asset_cls = KnowledgeAsset per_path = asset_cls.resolve_setting(per_path, "per_path") find_all = asset_cls.resolve_setting(find_all, "find_all") keep_path = asset_cls.resolve_setting(keep_path, "keep_path") skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing") in_dumps = asset_cls.resolve_setting(in_dumps, "in_dumps") dump_kwargs = asset_cls.resolve_setting(dump_kwargs, "dump_kwargs", merge=True) template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True) template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context) return_type = asset_cls.resolve_setting(return_type, "return_type") return_path = asset_cls.resolve_setting(return_path, "return_path") if path is not None: if isinstance(path, list): path = [search_.resolve_pathlike_key(p) for p in path] else: path = search_.resolve_pathlike_key(path) if per_path: if not isinstance(target, list): target = [target] if isinstance(path, list): target *= len(path) if not isinstance(path, list): path = [path] if isinstance(target, list): path *= len(target) if len(target) != len(path): raise ValueError("Number of targets must match number of paths") if source is not None: if isinstance(source, str): source = RepEval(source) elif checks.is_function(source): if checks.is_builtin_func(source): source = RepFunc(lambda _source=source: _source) else: source = RepFunc(source) elif not isinstance(source, CustomTemplate): raise TypeError(f"Source must be a string, function, or template, not {type(source)}") dump_kwargs = DumpAssetFunc.resolve_dump_kwargs(**dump_kwargs) contains_arg_names = set(get_func_arg_names(search_.contains_in_obj)) search_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in contains_arg_names} if "excl_types" not in search_kwargs: search_kwargs["excl_types"] = (tuple, set, frozenset) return (), { **dict( target=target, path=path, 
per_path=per_path, find_all=find_all, keep_path=keep_path, skip_missing=skip_missing, source=source, in_dumps=in_dumps, dump_kwargs=dump_kwargs, search_kwargs=search_kwargs, template_context=template_context, return_type=return_type, return_path=return_path, ), **kwargs, } @classmethod def match_func( cls, k: tp.Optional[tp.Hashable], d: tp.Any, target: tp.MaybeList[tp.Any], find_all: bool = False, **kwargs, ) -> bool: """Match function for `FindAssetFunc.call`. Uses `vectorbtpro.utils.search_.find` with `return_type="bool"` for text, and equality checks for other types. Target can be a function taking the value and returning a boolean. Target can also be an instance of `vectorbtpro.utils.search_.Not` for negation.""" if not isinstance(target, list): targets = [target] else: targets = target for target in targets: if isinstance(target, search_.Not): target = target.value negation = True else: negation = False if checks.is_function(target): if target(d): if (negation and find_all) or (not negation and not find_all): return not negation continue elif d is target: if (negation and find_all) or (not negation and not find_all): return not negation continue elif d is None and target is None: if (negation and find_all) or (not negation and not find_all): return not negation continue elif checks.is_bool(d) and checks.is_bool(target): if d == target: if (negation and find_all) or (not negation and not find_all): return not negation continue elif checks.is_number(d) and checks.is_number(target): if d == target: if (negation and find_all) or (not negation and not find_all): return not negation continue elif isinstance(d, str) and isinstance(target, str): if search_.find(target, d, return_type="bool", **kwargs): if (negation and find_all) or (not negation and not find_all): return not negation continue elif type(d) is type(target): try: if d == target: if (negation and find_all) or (not negation and not find_all): return not negation continue except Exception: pass if 
(negation and not find_all) or (not negation and find_all): return negation if find_all: return True return False @classmethod def call( cls, d: tp.Any, target: tp.MaybeList[tp.Any], path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None, per_path: bool = True, find_all: bool = False, keep_path: bool = False, skip_missing: bool = False, source: tp.Optional[tp.CustomTemplate] = None, in_dumps: bool = False, dump_kwargs: tp.KwargsLike = None, search_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, return_type: str = "item", return_path: bool = False, **kwargs, ) -> tp.Any: if dump_kwargs is None: dump_kwargs = {} if search_kwargs is None: search_kwargs = {} if per_path: new_path_dct = {} new_list = [] for i, p in enumerate(path): x = d try: x = search_.get_pathlike_key(x, p, keep_path=keep_path) except (KeyError, IndexError, AttributeError) as e: if not skip_missing: raise e continue if source is not None: _template_context = flat_merge_dicts( { "d": d, "x": x, **(x if isinstance(x, dict) else {}), }, template_context, ) _x = source.substitute(_template_context, eval_id="source") if checks.is_function(_x): x = _x(x) else: x = _x if not isinstance(x, str) and in_dumps: x = dump(x, **dump_kwargs) t = target[i] if return_type.lower() in ("item", "bool"): if isinstance(t, search_.Not): t = t.value negation = True else: negation = False if search_.contains_in_obj( x, cls.match_func, target=t, find_all=find_all, **search_kwargs, **kwargs, ): if negation: if find_all: return NoResult if return_type.lower() == "item" else False continue else: if not find_all: return d if return_type.lower() == "item" else True continue else: if negation: if not find_all: return d if return_type.lower() == "item" else True continue else: if find_all: return NoResult if return_type.lower() == "item" else False continue else: path_dct = search_.find_in_obj( x, cls.match_func, target=t, find_all=find_all, **search_kwargs, **kwargs, ) if len(path_dct) == 0: if find_all: 
return {} if return_path else [] continue if isinstance(t, search_.Not): raise TypeError("Target cannot be negated here") if not isinstance(t, str): raise ValueError("Target must be string") for k, v in path_dct.items(): if not isinstance(v, str): raise ValueError("Matched value must be string") _return_type = "bool" if return_type.lower() == "field" else return_type matches = search_.find(t, v, return_type=_return_type, **kwargs) if return_path: if k not in new_path_dct: new_path_dct[k] = [] if return_type.lower() == "field": if matches: new_path_dct[k].append(v) else: new_path_dct[k].extend(matches) else: if return_type.lower() == "field": if matches: new_list.append(v) else: new_list.extend(matches) if return_type.lower() in ("item", "bool"): if find_all: return d if return_type.lower() == "item" else True return NoResult if return_type.lower() == "item" else False else: if return_path: return new_path_dct return new_list else: x = d if path is not None: if isinstance(path, list): xs = [] for p in path: try: xs.append(search_.get_pathlike_key(x, p, keep_path=True)) except (KeyError, IndexError, AttributeError) as e: if not skip_missing: raise e continue if len(xs) == 0: if return_type.lower() == "item": return NoResult if return_type.lower() == "bool": return False return {} if return_path else [] x = merge_dicts(*xs) else: try: x = search_.get_pathlike_key(x, path, keep_path=keep_path) except (KeyError, IndexError, AttributeError) as e: if not skip_missing: raise e if return_type.lower() == "item": return NoResult if return_type.lower() == "bool": return False return {} if return_path else [] if source is not None: _template_context = flat_merge_dicts( { "d": d, "x": x, **(x if isinstance(x, dict) else {}), }, template_context, ) _x = source.substitute(_template_context, eval_id="source") if checks.is_function(_x): x = _x(x) else: x = _x if not isinstance(x, str) and in_dumps: x = dump(x, **dump_kwargs) if return_type.lower() == "item": if 
class FindReplaceAssetFunc(FindAssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.find_replace`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "find_replace"

    @classmethod
    def prepare(
        cls,
        target: tp.Union[dict, tp.MaybeList[tp.Any]],
        replacement: tp.Optional[tp.MaybeList[tp.Any]] = None,
        path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        per_path: tp.Optional[bool] = None,
        find_all: tp.Optional[bool] = None,
        keep_path: tp.Optional[bool] = None,
        skip_missing: tp.Optional[bool] = None,
        make_copy: tp.Optional[bool] = None,
        changed_only: tp.Optional[bool] = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve settings and normalize targets, replacements, and paths
        for `FindReplaceAssetFunc.call`.

        If `replacement` is None, `target` must be a dict whose keys are used as targets
        and values as replacements. A single path (or no path, represented by None) is
        repeated to match the number of targets or, if `target` is not a list, the number
        of replacements. With `per_path` enabled, single targets and replacements are
        repeated to match the number of paths, and all three must end up with the same length.

        Keyword arguments accepted by `vectorbtpro.utils.search_.find_in_obj` are moved from
        `kwargs` into `find_kwargs`, where `excl_types` defaults to `(tuple, set, frozenset)`.

        Raises `ValueError` if, with `per_path` enabled, the numbers of targets,
        replacements, and paths don't all match.

        Returns an empty tuple of positional arguments and the keyword arguments
        for `FindReplaceAssetFunc.call`; remaining `kwargs` are forwarded as they are."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

            asset_cls = KnowledgeAsset
        per_path = asset_cls.resolve_setting(per_path, "per_path")
        find_all = asset_cls.resolve_setting(find_all, "find_all")
        keep_path = asset_cls.resolve_setting(keep_path, "keep_path")
        skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
        make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
        changed_only = asset_cls.resolve_setting(changed_only, "changed_only")

        if replacement is None:
            checks.assert_instance_of(target, dict, arg_name="target")
            replacement = list(target.values())
            target = list(target.keys())
        if path is not None:
            if isinstance(path, list):
                paths = [search_.resolve_pathlike_key(p) for p in path]
            else:
                paths = [search_.resolve_pathlike_key(path)]
                if isinstance(target, list):
                    paths *= len(target)
                elif isinstance(replacement, list):
                    paths *= len(replacement)
        else:
            paths = [None]
            if isinstance(target, list):
                paths *= len(target)
            elif isinstance(replacement, list):
                paths *= len(replacement)
        if per_path:
            if not isinstance(target, list):
                target = [target] * len(paths)
            if not isinstance(replacement, list):
                replacement = [replacement] * len(paths)
            # A chained `a != b != c` only checks adjacent pairs and is True only if both
            # differ; all three lengths must be equal here
            if not (len(target) == len(replacement) == len(paths)):
                raise ValueError("Number of targets and replacements must match number of paths")
        find_arg_names = set(get_func_arg_names(search_.find_in_obj))
        find_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in find_arg_names}
        if "excl_types" not in find_kwargs:
            find_kwargs["excl_types"] = (tuple, set, frozenset)
        return (), {
            **dict(
                target=target,
                replacement=replacement,
                paths=paths,
                per_path=per_path,
                find_all=find_all,
                keep_path=keep_path,
                skip_missing=skip_missing,
                make_copy=make_copy,
                changed_only=changed_only,
                find_kwargs=find_kwargs,
            ),
            **kwargs,
        }

    @classmethod
    def replace_func(
        cls,
        k: tp.Optional[tp.Hashable],
        d: tp.Any,
        target: tp.MaybeList[tp.Any],
        replacement: tp.MaybeList[tp.Any],
        **kwargs,
    ) -> tp.Any:
        """Replace function for `FindReplaceAssetFunc.call`.

        Uses `vectorbtpro.utils.search_.replace` for text and returns replacement for other types.

        Target can be a function taking the value and returning a boolean.
        Replacement can also be a function taking the value and returning a new value.

        A single replacement is shared by all targets. Targets are processed in order:
        a non-text match returns its replacement immediately, while text replacements
        are applied to the value cumulatively. The (possibly modified) value is returned
        if no non-text target matches.

        Raises `ValueError` if the numbers of targets and replacements differ, and
        `TypeError` if a target is an instance of `vectorbtpro.utils.search_.Not`."""
        if not isinstance(target, list):
            targets = [target]
        else:
            targets = target
        if not isinstance(replacement, list):
            replacements = [replacement]
            if len(targets) > 1 and len(replacements) == 1:
                replacements *= len(targets)
        else:
            replacements = replacement
        if len(targets) != len(replacements):
            raise ValueError("Number of targets must match number of replacements")
        for i, target in enumerate(targets):
            if isinstance(target, search_.Not):
                raise TypeError("Target cannot be negated here")
            replacement = replacements[i]
            if checks.is_function(replacement):
                replacement = replacement(d)
            if checks.is_function(target):
                if target(d):
                    return replacement
            elif d is target:
                return replacement
            elif d is None and target is None:
                return replacement
            elif checks.is_bool(d) and checks.is_bool(target):
                if d == target:
                    return replacement
            elif checks.is_number(d) and checks.is_number(target):
                if d == target:
                    return replacement
            elif isinstance(d, str) and isinstance(target, str):
                # Text is replaced in place and the loop continues with the next target
                d = search_.replace(target, replacement, d, **kwargs)
            elif type(d) is type(target):
                try:
                    if d == target:
                        return replacement
                except Exception:
                    # Values whose comparison fails are treated as not matching
                    pass
        return d

    @classmethod
    def call(
        cls,
        d: tp.Any,
        target: tp.MaybeList[tp.Any],
        replacement: tp.MaybeList[tp.Any],
        paths: tp.List[tp.PathLikeKey],
        per_path: bool = True,
        find_all: bool = False,
        keep_path: bool = False,
        skip_missing: bool = False,
        make_copy: bool = True,
        changed_only: bool = False,
        find_kwargs: tp.KwargsLike = None,
        **kwargs,
    ) -> tp.Any:
        """Find targets in the data item `d` and replace the matched values.

        For each path, the object is looked up (`d` itself if the path is None) and searched
        with `vectorbtpro.utils.search_.find_in_obj` using `FindAssetFunc.match_func`. With
        `per_path` enabled, the i-th target and replacement are used for the i-th path;
        otherwise all targets and replacements are used for each path. Each matched value is
        passed to `FindReplaceAssetFunc.replace_func` and the result is written back.

        If `find_all` is True, a first pass checks that every path that was not skipped
        yields at least one match; if any of them yields none, nothing is replaced.

        If a path cannot be resolved, the error is re-raised unless `skip_missing`
        is True, in which case the path is skipped.

        Returns the updated data item, or `NoResult` if `changed_only` is True and
        `prev_keys` was left empty by `search_.set_pathlike_key`."""
        if find_kwargs is None:
            find_kwargs = {}
        prev_keys = []
        found_all = True
        if find_all:
            # Dry run: don't modify anything unless each path has a match
            for i, p in enumerate(paths):
                x = d
                if p is not None:
                    try:
                        x = search_.get_pathlike_key(x, p, keep_path=keep_path)
                    except (KeyError, IndexError, AttributeError):
                        if not skip_missing:
                            raise
                        continue
                path_dct = search_.find_in_obj(
                    x,
                    cls.match_func,
                    target=target[i] if per_path else target,
                    find_all=find_all,
                    **find_kwargs,
                    **kwargs,
                )
                if len(path_dct) == 0:
                    found_all = False
                    break
        if found_all:
            for i, p in enumerate(paths):
                x = d
                if p is not None:
                    try:
                        x = search_.get_pathlike_key(x, p, keep_path=keep_path)
                    except (KeyError, IndexError, AttributeError):
                        if not skip_missing:
                            raise
                        continue
                path_dct = search_.find_in_obj(
                    x,
                    cls.match_func,
                    target=target[i] if per_path else target,
                    find_all=find_all,
                    **find_kwargs,
                    **kwargs,
                )
                for k, v in path_dct.items():
                    if p is not None and not keep_path:
                        # Found keys are relative to the object under the path
                        new_p = search_.combine_pathlike_keys(p, k, minimize=True)
                    else:
                        new_p = k
                    v = cls.replace_func(
                        k,
                        v,
                        target[i] if per_path else target,
                        replacement[i] if per_path else replacement,
                        **kwargs,
                    )
                    d = search_.set_pathlike_key(d, new_p, v, make_copy=make_copy, prev_keys=prev_keys)
        if not changed_only or len(prev_keys) > 0:
            return d
        return NoResult
@classmethod
def call(
    cls,
    d: tp.Any,
    target: tp.MaybeList[tp.Any],
    paths: tp.List[tp.PathLikeKey],
    per_path: bool = True,
    find_all: bool = False,
    keep_path: bool = False,
    skip_missing: bool = False,
    make_copy: bool = True,
    changed_only: bool = False,
    find_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.Any:
    """Find matches under each path of a data item and remove them.

    All matching keys are collected first (searching with
    `vectorbtpro.utils.search_.find_in_obj` and `cls.match_func`) and only then removed
    one by one with `vectorbtpro.utils.search_.remove_pathlike_key`.

    If `find_all` is True and some path yields no match, the collected keys are
    discarded and nothing is removed. A path that raises `KeyError`, `IndexError`, or
    `AttributeError` is skipped when `skip_missing` is True, otherwise the error
    propagates.

    Returns `NoResult` if `changed_only` is True and `prev_keys` is still empty after
    processing, otherwise the (possibly modified) data item."""
    search_kwargs = {} if find_kwargs is None else find_kwargs
    prev_keys = []
    # Insertion-ordered mapping used to de-duplicate the keys scheduled for removal
    pending = {}
    for idx, path in enumerate(paths):
        sub_obj = d
        if path is not None:
            try:
                sub_obj = search_.get_pathlike_key(sub_obj, path, keep_path=keep_path)
            except (KeyError, IndexError, AttributeError):
                if skip_missing:
                    continue
                raise
        matches = search_.find_in_obj(
            sub_obj,
            cls.match_func,
            target=target[idx] if per_path else target,
            find_all=find_all,
            **search_kwargs,
            **kwargs,
        )
        if find_all and len(matches) == 0:
            # One path without a match cancels the whole removal
            pending = {}
            break
        needs_prefix = path is not None and not keep_path
        for key, value in matches.items():
            if needs_prefix:
                full_key = search_.combine_pathlike_keys(path, key, minimize=True)
            else:
                full_key = key
            pending[full_key] = value
    for full_key in pending:
        d = search_.remove_pathlike_key(d, full_key, make_copy=make_copy, prev_keys=prev_keys)
    if changed_only and len(prev_keys) == 0:
        return NoResult
    return d
class UnflattenAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.unflatten`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "unflatten"

    _wrap: tp.ClassVar[tp.Optional[str]] = True

    @classmethod
    def prepare(
        cls,
        path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        skip_missing: tp.Optional[bool] = None,
        make_copy: tp.Optional[bool] = None,
        changed_only: tp.Optional[bool] = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve settings and normalize `path` into a list of path-like keys.

        Unset options are resolved through `asset_cls.resolve_setting`. A missing
        `path` becomes `[None]`. Remaining keyword arguments are passed through and
        later forwarded to `vectorbtpro.utils.search_.unflatten_obj` by
        `UnflattenAssetFunc.call`."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

            asset_cls = KnowledgeAsset
        skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
        make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
        changed_only = asset_cls.resolve_setting(changed_only, "changed_only")

        if path is None:
            paths = [None]
        elif isinstance(path, list):
            paths = [search_.resolve_pathlike_key(p) for p in path]
        else:
            paths = [search_.resolve_pathlike_key(path)]
        return (), {
            "paths": paths,
            "skip_missing": skip_missing,
            "make_copy": make_copy,
            "changed_only": changed_only,
            **kwargs,
        }

    @classmethod
    def call(
        cls,
        d: tp.Any,
        paths: tp.List[tp.PathLikeKey],
        skip_missing: bool = False,
        make_copy: bool = True,
        changed_only: bool = False,
        **kwargs,
    ) -> tp.Any:
        """Unflatten the object under each path and write the result back.

        A path that raises `KeyError`, `IndexError`, or `AttributeError` is skipped
        when `skip_missing` is True, otherwise the error propagates.

        Returns `NoResult` if `changed_only` is True and `prev_keys` is still empty
        after processing, otherwise the (possibly modified) data item."""
        prev_keys = []
        for path in paths:
            if path is None:
                sub_obj = d
            else:
                try:
                    sub_obj = search_.get_pathlike_key(d, path)
                except (KeyError, IndexError, AttributeError):
                    if skip_missing:
                        continue
                    raise
            unflattened = search_.unflatten_obj(sub_obj, **kwargs)
            d = search_.set_pathlike_key(d, path, unflattened, make_copy=make_copy, prev_keys=prev_keys)
        if changed_only and len(prev_keys) == 0:
            return NoResult
        return d
@classmethod
def prepare(
    cls,
    source: tp.Optional[tp.CustomTemplateLike] = None,
    dump_engine: tp.Optional[str] = None,
    template_context: tp.KwargsLike = None,
    asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
    **kwargs,
) -> tp.ArgsKwargs:
    """Resolve the template context, dump settings, and source template.

    Args:
        source: String (wrapped with `RepEval`), function (wrapped with `RepFunc`),
            or `CustomTemplate` instance. None leaves the source unset.
        dump_engine: Name of the dump engine; resolved by `DumpAssetFunc.resolve_dump_kwargs`.
        template_context: Context merged over the "template_context" setting.
            The key "asset_cls" is added as a default.
        asset_cls: Asset class whose settings are used. Defaults to `KnowledgeAsset`.
        **kwargs: Engine-related keyword arguments.

    Returns:
        Empty positional arguments and the keyword arguments for `DumpAssetFunc.call`.

    Raises:
        TypeError: If `source` is not a string, function, or template."""
    if asset_cls is None:
        from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

        asset_cls = KnowledgeAsset
    template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
    template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
    # Forward the asset class so that the dump settings are resolved against the same
    # class as the other settings instead of always falling back to KnowledgeAsset
    dump_kwargs = cls.resolve_dump_kwargs(dump_engine=dump_engine, asset_cls=asset_cls, **kwargs)

    if source is not None:
        if isinstance(source, str):
            source = RepEval(source)
        elif checks.is_function(source):
            if checks.is_builtin_func(source):
                # Built-in functions are wrapped in a lambda that returns them as-is;
                # `DumpAssetFunc.call` then applies the returned function to the data item
                source = RepFunc(lambda _source=source: _source)
            else:
                source = RepFunc(source)
        elif not isinstance(source, CustomTemplate):
            raise TypeError(f"Source must be a string, function, or template, not {type(source)}")
    return (), {
        **dict(
            source=source,
            template_context=template_context,
        ),
        **dump_kwargs,
        **kwargs,
    }
@classmethod
def prepare(
    cls,
    asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
    document_cls: tp.Optional[tp.Type[tp.StoreDocument]] = None,
    template_context: tp.Union[tp.KwargsLike, tp.CustomTemplate] = None,
    **document_kwargs,
) -> tp.ArgsKwargs:
    """Resolve the document class and the keyword arguments used to build documents.

    Args:
        asset_cls: Asset class whose settings are used. Defaults to `KnowledgeAsset`.
        document_cls: Document class. Resolved from the "document_cls" setting and
            falls back to `TextDocument` if still None.
        template_context: Context merged over the "template_context" setting.
            The key "asset_cls" is added as a default.
        **document_kwargs: Overrides for fields of `document_cls`. For each field with
            a default, a value passed here or found under the "document_kwargs"
            settings sub-path is resolved through `asset_cls.resolve_setting`.

    Returns:
        Empty positional arguments and the keyword arguments for `ToDocsAssetFunc.call`."""
    if asset_cls is None:
        from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset

        asset_cls = KnowledgeAsset
    document_cls = asset_cls.resolve_setting(document_cls, "document_cls")
    if document_cls is None:
        from vectorbtpro.utils.knowledge.chatting import TextDocument

        document_cls = TextDocument
    template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
    template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)

    # Work on a copy of the caller's keyword arguments. Resetting them to an empty dict
    # here would discard every user-supplied override before it is read below.
    document_kwargs = dict(document_kwargs)
    for k, v in document_cls.fields_dict.items():
        if v.default is MISSING:
            continue
        if k not in document_kwargs and not asset_cls.has_setting(k, sub_path="document_kwargs"):
            continue
        document_kwargs[k] = asset_cls.resolve_setting(
            document_kwargs.get(k, None),
            k,
            sub_path="document_kwargs",
            # Dict-valued fields are merged with the settings rather than replaced
            merge=isinstance(v.default, attr.Factory) and v.default.factory is dict,
        )
        if k == "template_context":
            document_kwargs[k] = merge_dicts(template_context, document_kwargs[k])
        if k == "dump_kwargs":
            document_kwargs[k] = DumpAssetFunc.resolve_dump_kwargs(**document_kwargs[k])
    # Non-strict substitution; `ToDocsAssetFunc.call` substitutes again per data item
    document_kwargs = substitute_templates(
        document_kwargs, template_context, eval_id="document_kwargs", strict=False
    )
    return (), {
        **dict(document_cls=document_cls),
        **document_kwargs,
    }
@classmethod
def call(cls, d1: tp.Any, d2: tp.Any, sort_keys: bool = False) -> tp.Any:
    """Fold data item `d2` into the accumulator `d1`.

    Lists are first converted to dicts keyed by position. Every key of `d2` that
    already exists in `d1` has its value appended to the list stored under that key;
    every new key starts a one-element list. Keys of `d1` keep their order and new
    keys follow in the order of `d2`.

    Args:
        d1: Accumulator (dict or list); values under keys shared with `d2` must
            support `append`.
        d2: Data item (dict or list) to collect.
        sort_keys: Whether to sort the keys of the result with `CollectAssetFunc.sort_key`.

    Returns:
        New dict with the collected values.

    Raises:
        TypeError: If either argument is neither a dict nor a list."""
    if isinstance(d1, list):
        d1 = dict(enumerate(d1))
    if isinstance(d2, list):
        d2 = dict(enumerate(d2))
    if not isinstance(d1, dict) or not isinstance(d2, dict):
        raise TypeError(f"Data items must be either dicts or lists, not {type(d1)} and {type(d2)}")
    # Shallow copy: the dict is new, but the lists are shared with `d1` and extended in place
    new_d1 = dict(d1)
    for k2, v2 in d2.items():
        if k2 in new_d1:
            new_d1[k2].append(v2)
        else:
            new_d1[k2] = [v2]
    if sort_keys:
        return dict(sorted(new_d1.items(), key=lambda x: cls.sort_key(x[0])))
    return new_d1
# =================================================================================== """Base asset classes. See `vectorbtpro.utils.knowledge` for the toy dataset.""" import hashlib import json import re from collections.abc import MutableSequence from pathlib import Path import pandas as pd from vectorbtpro import _typing as tp from vectorbtpro.utils import checks from vectorbtpro.utils.config import Configured from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts from vectorbtpro.utils.decorators import hybrid_method from vectorbtpro.utils.execution import Task, execute, NoResult from vectorbtpro.utils.knowledge.chatting import RankContextable from vectorbtpro.utils.module_ import get_caller_qualname from vectorbtpro.utils.parsing import get_func_arg_names from vectorbtpro.utils.path_ import dir_tree_from_paths, remove_dir, check_mkdir from vectorbtpro.utils.pbar import ProgressBar from vectorbtpro.utils.pickling import decompress, dumps, load_bytes, save, load from vectorbtpro.utils.search_ import flatten_obj, unflatten_obj from vectorbtpro.utils.template import CustomTemplate, RepEval, RepFunc from vectorbtpro.utils.warnings_ import warn __all__ = [ "AssetCacheManager", "KnowledgeAsset", ] asset_cache: tp.Dict[tp.Hashable, "KnowledgeAsset"] = {} """Asset cache.""" class AssetCacheManager(Configured): """Class for managing knowledge asset cache. 
def __init__(
    self,
    persist_cache: tp.Optional[bool] = None,
    cache_dir: tp.Optional[tp.PathLike] = None,
    cache_mkdir_kwargs: tp.KwargsLike = None,
    clear_cache: tp.Optional[bool] = None,
    max_cache_count: tp.Optional[int] = None,
    save_cache_kwargs: tp.KwargsLike = None,
    load_cache_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> None:
    # The config stores the arguments exactly as passed; settings are resolved afterwards
    Configured.__init__(
        self,
        persist_cache=persist_cache,
        cache_dir=cache_dir,
        cache_mkdir_kwargs=cache_mkdir_kwargs,
        clear_cache=clear_cache,
        max_cache_count=max_cache_count,
        save_cache_kwargs=save_cache_kwargs,
        load_cache_kwargs=load_cache_kwargs,
        template_context=template_context,
        **kwargs,
    )

    # Note that `persist_cache` and `cache_dir` are resolved under the setting
    # names "cache" and "asset_cache_dir" respectively
    persist_cache = self.resolve_setting(persist_cache, "cache")
    cache_dir = self.resolve_setting(cache_dir, "asset_cache_dir")
    cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True)
    clear_cache = self.resolve_setting(clear_cache, "clear_cache")
    max_cache_count = self.resolve_setting(max_cache_count, "max_cache_count")
    save_cache_kwargs = self.resolve_setting(save_cache_kwargs, "save_cache_kwargs", merge=True)
    load_cache_kwargs = self.resolve_setting(load_cache_kwargs, "load_cache_kwargs", merge=True)
    template_context = self.resolve_setting(template_context, "template_context", merge=True)

    if isinstance(cache_dir, CustomTemplate):
        # The asset cache directory is a template that may refer to the general cache
        # directory, so the latter is resolved first and exposed as `cache_dir`
        asset_dir_template = cache_dir
        base_dir = self.get_setting("cache_dir")
        if isinstance(base_dir, CustomTemplate):
            base_dir = base_dir.substitute(template_context, eval_id="cache_dir")
        template_context = flat_merge_dicts(dict(cache_dir=base_dir), template_context)
        cache_dir = asset_dir_template.substitute(template_context, eval_id="asset_cache_dir")
    cache_path = Path(cache_dir)
    if cache_path.exists() and clear_cache:
        remove_dir(cache_path, missing_ok=True, with_contents=True)
    check_mkdir(cache_path, **cache_mkdir_kwargs)

    self._persist_cache = persist_cache
    self._cache_dir = cache_path
    self._max_cache_count = max_cache_count
    self._save_cache_kwargs = save_cache_kwargs
    self._load_cache_kwargs = load_cache_kwargs
    self._template_context = template_context
def save_asset(self, asset: tp.MaybeKnowledgeAsset, cache_key: str) -> tp.Optional[tp.Path]:
    """Register a knowledge asset in the in-memory cache and, if `AssetCacheManager.persist_cache`
    is enabled, also save it to `AssetCacheManager.cache_dir` under the cache key.

    After saving to disk, `AssetCacheManager.cleanup_cache_dir` is called.

    Returns the path returned by `vectorbtpro.utils.pickling.save`, or None if the
    asset was not persisted."""
    asset_cache[cache_key] = asset
    if not self.persist_cache:
        return None
    saved_path = save(asset, path=self.cache_dir / cache_key, **self.save_cache_kwargs)
    self.cleanup_cache_dir()
    return saved_path
@hybrid_method
def merge(
    cls_or_self: tp.MaybeType[KnowledgeAssetT],
    *objs: tp.MaybeTuple[KnowledgeAssetT],
    flatten_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> KnowledgeAssetT:
    """Either merge multiple `KnowledgeAsset` instances into one if called as a class method or
    instance method with at least one additional object, or merge data items of a single instance
    if called as an instance method with no additional objects.

    When merging instances, the data items at the same position are flattened with
    `vectorbtpro.utils.search_.flatten_obj` (using `flatten_kwargs`), merged with
    `vectorbtpro.utils.config.flat_merge_dicts`, and unflattened again. An instance with
    exactly one data item is broadcast against the longest instance.

    Raises `ValueError` if called on a single instance whose first data item is neither
    a list nor a dict, and `TypeError` if any object is not a `KnowledgeAsset`.

    Usage:
        ```pycon
        >>> asset1 = asset.select(["s"])
        >>> asset2 = asset.select(["b", "d2"])
        >>> asset1.merge(asset2).get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}}]
        ```
    """
    if not isinstance(cls_or_self, type) and len(objs) == 0:
        if isinstance(cls_or_self[0], list):
            return cls_or_self.merge_lists(**kwargs)
        if isinstance(cls_or_self[0], dict):
            return cls_or_self.merge_dicts(**kwargs)
        raise ValueError("Cannot determine type of data items. Use merge_lists or merge_dicts.")
    elif not isinstance(cls_or_self, type) and len(objs) > 0:
        objs = (cls_or_self, *objs)
        cls = type(cls_or_self)
    else:
        cls = cls_or_self
    if len(objs) == 1:
        # A single positional argument is treated as a collection of assets
        objs = objs[0]
    objs = list(objs)
    for obj in objs:
        if not checks.is_instance_of(obj, KnowledgeAsset):
            raise TypeError("Each object to be merged must be an instance of KnowledgeAsset")

    # Copy so that the defaults below never leak into the caller's dict
    flatten_kwargs = {} if flatten_kwargs is None else dict(flatten_kwargs)
    flatten_kwargs.setdefault("annotate_all", True)
    flatten_kwargs.setdefault("excl_types", (tuple, set, frozenset))

    max_items = 1
    new_single_item = True
    for obj in objs:
        if len(obj.data) > max_items:
            max_items = len(obj.data)
        if not obj.single_item:
            new_single_item = False

    flat_data = []
    for obj in objs:
        obj_data = obj.data
        if len(obj_data) == 1:
            # Broadcast the single item itself. Repeating the enclosing list instead
            # would flatten `[item]` rather than `item` and corrupt the merged keys.
            obj_data = list(obj_data) * max_items
        flat_data.append([flatten_obj(x, **flatten_kwargs) for x in obj_data])

    new_data = []
    # NOTE(review): zip stops at the shortest instance, so instances of unequal
    # length (other than length one) are silently truncated, as before
    for flat_dcts in zip(*flat_data):
        merged_flat_dct = flat_merge_dicts(*flat_dcts)
        new_data.append(unflatten_obj(merged_flat_dct))
    kwargs = cls_or_self.resolve_merge_kwargs(
        *[obj.config for obj in objs],
        single_item=new_single_item,
        data=new_data,
        **kwargs,
    )
    return cls(**kwargs)
def get_items(self, index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]]) -> tp.Any:
    """Get one or more data items.

    Args:
        index: Integer, slice, or iterable of either all booleans (a mask with the
            same length as the data) or all integers (positions).

    Returns:
        The data item itself for an integer index, otherwise a new instance (created
        with `replace`) holding the selected items. An empty iterable selects nothing
        and yields an instance with no data.

    Raises:
        IndexError: If a boolean mask does not have the same length as the data.
        TypeError: If an iterable index mixes types or contains other types."""
    if checks.is_complex_iterable(index):
        # Materialize once: the type scans below would otherwise exhaust a one-shot
        # iterable before its elements are used
        index = list(index)
        if len(index) == 0:
            # An empty index is neither a mask nor positions; it simply selects nothing
            return self.replace(data=[])
        if all(checks.is_bool(i) for i in index):
            if len(index) != len(self.data):
                raise IndexError("Boolean index must have the same length as data")
            return self.replace(data=[item for item, flag in zip(self.data, index) if flag])
        if all(checks.is_int(i) for i in index):
            return self.replace(data=[self.data[i] for i in index])
        raise TypeError("Index must contain all integers or all booleans")
    if isinstance(index, slice):
        return self.replace(data=self.data[index])
    return self.data[index]
def delete_items(
    self: KnowledgeAssetT,
    index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]],
    inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
    """Delete one or more data items.

    Args:
        index: Integer, slice, or iterable of either all booleans (a mask with the
            same length as the data) or all integers (positions; negative values
            count from the end). An empty iterable deletes nothing.
        inplace: Whether to modify this instance instead of returning a new one.

    Returns a new `KnowledgeAsset` instance if `inplace` is False.

    Raises:
        IndexError: If a boolean mask does not have the same length as the data,
            or an integer position is out of range.
        TypeError: If an iterable index mixes types or contains other types."""
    new_data = list(self.data)
    if checks.is_complex_iterable(index):
        # Materialize once: the type scans below would otherwise exhaust a one-shot
        # iterable before its elements are used
        index = list(index)
        if len(index) == 0:
            # Nothing selected, nothing to delete
            pass
        elif all(checks.is_bool(i) for i in index):
            if len(index) != len(new_data):
                raise IndexError("Boolean index must have the same length as data")
            new_data = [item for item, flag in zip(new_data, index) if not flag]
        elif all(checks.is_int(i) for i in index):
            n_items = len(new_data)
            indices_to_remove = set()
            for i in index:
                if not -n_items <= i < n_items:
                    raise IndexError(f"Index {i} out of range")
                # Normalize negative positions; enumerate() below only yields
                # non-negative ones, so raw negatives would never match
                indices_to_remove.add(i + n_items if i < 0 else i)
            new_data = [item for i, item in enumerate(new_data) if i not in indices_to_remove]
        else:
            raise TypeError("Index must contain all integers or all booleans")
    else:
        del new_data[index]
    if inplace:
        self.modify_data(new_data)
        return None
    return self.replace(data=new_data)
def unique(
    self: KnowledgeAssetT,
    *args,
    keep: str = "first",
    inplace: bool = False,
    **kwargs,
) -> tp.Optional[KnowledgeAssetT]:
    """De-duplicate based on `KnowledgeAsset.get` called on `*args` and `**kwargs`.

    The value returned by `KnowledgeAsset.get` for each data item acts as its key.
    Argument `keep` selects which of the duplicates survives: "first" or "last"
    (case-insensitive). The relative order of the surviving items is preserved.

    Keys do not need to be hashable: unhashable keys (such as lists or dicts)
    are compared by equality instead.

    Raises `ValueError` if `keep` is neither "first" nor "last".

    Returns a new `KnowledgeAsset` instance if `inplace` is False.

    Usage:
        ```pycon
        >>> asset.unique("b").get()
        [{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
        ```
    """
    keep_mode = keep.lower()
    if keep_mode not in ("first", "last"):
        # Fail before doing any work on the data
        raise ValueError(f"Invalid keep option: '{keep}'")
    keep_last = keep_mode == "last"

    keys = self.get(*args, **kwargs)
    if keep_last:
        # Walking backwards turns "keep last" into "keep first"
        pairs = zip(reversed(keys), reversed(self.data))
    else:
        pairs = zip(keys, self.data)

    seen_hashable = set()
    seen_unhashable = []
    new_data = []
    for key, item in pairs:
        try:
            is_duplicate = key in seen_hashable
            if not is_duplicate:
                seen_hashable.add(key)
        except TypeError:
            # Unhashable key: fall back to an equality scan
            is_duplicate = key in seen_unhashable
            if not is_duplicate:
                seen_unhashable.append(key)
        if not is_duplicate:
            new_data.append(item)
    if keep_last:
        new_data.reverse()

    if inplace:
        self.modify_data(new_data)
        return None
    return self.replace(data=new_data)
def sample(
    self,
    k: tp.Optional[int] = None,
    seed: tp.Optional[int] = None,
    wrap: bool = True,
) -> tp.Any:
    """Pick a random sample of data items.

    Argument `k` is the number of items to draw (without replacement); it is capped
    at the number of available items. If `k` is None, a single item is drawn and
    the result is treated as a single item.

    If `seed` is provided, a dedicated random number generator seeded with it is used,
    such that the draw is reproducible without reseeding the global `random` state.

    If `wrap` is True, returns a new `KnowledgeAsset` instance. Otherwise, returns
    the list of drawn items, or, when `k` is None, the drawn item itself
    (None if there are no data items)."""
    import random

    single_item = k is None
    if single_item:
        k = 1
    # random.Random(seed) yields the same sequence as random.seed(seed) on the global generator
    rng = random if seed is None else random.Random(seed)
    new_data = rng.sample(self.data, min(len(self.data), k))
    if wrap:
        return self.replace(data=new_data, single_item=single_item)
    if single_item:
        # Mirror `apply`/`get`: an empty single-item result is None rather than an IndexError
        return new_data[0] if new_data else None
    return new_data
def __add__(self: KnowledgeAssetT, other: tp.Any) -> KnowledgeAssetT:
    """Combine this asset with `other` into a new asset.

    If `other` is not a `KnowledgeAsset`, it is wrapped in one first. The result is
    built by calling `combine` on the first class in this instance's MRO that also
    appears in the MRO of `other`, falling back to `KnowledgeAsset`."""
    rhs = other if isinstance(other, KnowledgeAsset) else KnowledgeAsset(other)
    rhs_lineage = set(rhs.__class__.mro())
    # Walk our own MRO so that the most specific shared class wins
    result_cls = next(
        (klass for klass in self.__class__.mro() if klass in rhs_lineage),
        KnowledgeAsset,
    )
    return result_cls.combine(self, rhs)
Function can be either a callable, a tuple of function and its arguments, a `vectorbtpro.utils.execution.Task` instance, a subclass of `vectorbtpro.utils.knowledge.base_asset_funcs.AssetFunc` or its prefix or full name. Moreover, function can be a list of the above. In such a case, `BasicAssetPipeline` will be used. If function is a valid expression, `ComplexAssetPipeline` will be used. Uses `vectorbtpro.utils.execution.execute` for execution. If `wrap` is True, returns a new `KnowledgeAsset` instance, otherwise raw output. Usage: ```pycon >>> asset.apply(["flatten", ("query", len)]) [5, 5, 5, 5, 6] >>> asset.apply("query(flatten(d), len)") [5, 5, 5, 5, 6] ``` """ from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline, BasicAssetPipeline, ComplexAssetPipeline execute_kwargs = self.resolve_setting(execute_kwargs, "execute_kwargs", merge=True) asset_func_meta = {} if isinstance(func, list): func, args, kwargs = ( BasicAssetPipeline( func, *args, cond_kwargs=dict(asset_cls=type(self)), asset_func_meta=asset_func_meta, **kwargs, ), (), {}, ) elif isinstance(func, str) and not func.isidentifier(): if len(args) > 0: raise ValueError("No more positional arguments can be applied to ComplexAssetPipeline") func, args, kwargs = ( ComplexAssetPipeline( func, context=kwargs.get("template_context", None), cond_kwargs=dict(asset_cls=type(self)), asset_func_meta=asset_func_meta, **kwargs, ), (), {}, ) elif not isinstance(func, AssetPipeline): func, args, kwargs = AssetPipeline.resolve_task( func, *args, cond_kwargs=dict(asset_cls=type(self)), asset_func_meta=asset_func_meta, **kwargs, ) else: if len(args) > 0: raise ValueError("No more positional arguments can be applied to AssetPipeline") if len(kwargs) > 0: raise ValueError("No more keyword arguments can be applied to AssetPipeline") prefix = get_caller_qualname().split(".")[-1] if "_short_name" in asset_func_meta: prefix += f"[{asset_func_meta['_short_name']}]" elif isinstance(func, type): prefix += 
def get(
    self: KnowledgeAssetT,
    path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
    keep_path: tp.Optional[bool] = None,
    skip_missing: tp.Optional[bool] = None,
    source: tp.Optional[tp.CustomTemplateLike] = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Get data items or parts of them.

    Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.GetAssetFunc`.

    Use argument `path` to specify what part of the data item should be got. For example, "x.y[0].z"
    to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented
    as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically
    becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True
    and path is missing in the data item, will skip the data item.

    Use argument `source` instead of `path` or in addition to `path` to also preprocess the source.
    It can be a string or function (will become a template), or any custom template. In this template,
    the index of the data item is represented by "i", the data item itself is represented by "d",
    the data item under the path is represented by "x" while its fields are represented by their names.

    If neither `path` nor `source` is provided, the data is returned as is without calling
    `KnowledgeAsset.apply`: the raw data by default (a single item is unpacked, an empty single item
    becomes None), or a new `KnowledgeAsset` instance if `wrap=True` is passed as a keyword argument.

    Usage:
        ```pycon
        >>> asset.get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]

        >>> asset.get("d2.l[0]")
        [1, 3, 5, 7, 9]

        >>> asset.get("d2.l", source=lambda x: sum(x))
        [3, 7, 11, 15, 19]

        >>> asset.get("d2.l[0]", keep_path=True)
        [{'d2': {'l': {0: 1}}},
         {'d2': {'l': {0: 3}}},
         {'d2': {'l': {0: 5}}},
         {'d2': {'l': {0: 7}}},
         {'d2': {'l': {0: 9}}}]

        >>> asset.get(["d2.l[0]", "d2.l[1]"])
        [{'d2': {'l': {0: 1, 1: 2}}},
         {'d2': {'l': {0: 3, 1: 4}}},
         {'d2': {'l': {0: 5, 1: 6}}},
         {'d2': {'l': {0: 7, 1: 8}}},
         {'d2': {'l': {0: 9, 1: 10}}}]

        >>> asset.get("xyz", skip_missing=True)
        [123]
        ```
    """
    if path is None and source is None:
        if kwargs.get("wrap", None):
            # Honor an explicit wrap request (e.g., from `select`) instead of returning raw data
            single_item = kwargs.get("single_item", None)
            if single_item is None:
                single_item = self.single_item
            return self.replace(data=list(self.data), single_item=single_item)
        if self.single_item:
            if len(self.data) == 1:
                return self.data[0]
            if len(self.data) == 0:
                return None
        return self.data
    return self.apply(
        "get",
        path=path,
        keep_path=keep_path,
        skip_missing=skip_missing,
        source=source,
        template_context=template_context,
        **kwargs,
    )
**kwargs, ) -> tp.MaybeKnowledgeAsset: """Set data items or parts of them. Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.SetAssetFunc`. Argument `value` can be any value, function (will become a template), or a template. In this template, the index of the data item is represented by "i", the data item itself is represented by "d", the data item under the path is represented by "x" while its fields are represented by their names. Use argument `path` to specify what part of the data item should be set. For example, "x.y[0].z" to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and path is missing in the data item, will skip the data item. Set `make_copy` to True to not modify original data. Set `changed_only` to True to keep only the data items that have been changed. Usage: ```pycon >>> asset.set(lambda d: sum(d["d2"]["l"])).get() [3, 7, 11, 15, 19] >>> asset.set(lambda d: sum(d["d2"]["l"]), path="d2.sum").get() >>> asset.set(lambda x: sum(x["l"]), path="d2.sum").get() >>> asset.set(lambda l: sum(l), path="d2.sum").get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2], 'sum': 3}}, {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4], 'sum': 7}}, {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6], 'sum': 11}}, {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8], 'sum': 15}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10], 'sum': 19}, 'xyz': 123}] >>> asset.set(lambda l: sum(l), path="d2.l").get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': 3}}, {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': 7}}, {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': 11}}, {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': 15}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': 19}, 'xyz': 123}] ``` """ return self.apply( "set", value=value, path=path, skip_missing=skip_missing, make_copy=make_copy, changed_only=changed_only, template_context=template_context, 
def remove(
    self: KnowledgeAssetT,
    path: tp.MaybeList[tp.PathLikeKey],
    skip_missing: tp.Optional[bool] = None,
    make_copy: tp.Optional[bool] = None,
    changed_only: tp.Optional[bool] = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Remove data items or parts of them.

    If `path` is an integer, removes the entire data item at that index.

    Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.RemoveAssetFunc`;
    all arguments, including `**kwargs`, are forwarded to it.

    Use argument `path` to specify what part of the data item should be removed. For example,
    "x.y[0].z" to navigate nested dictionaries/lists. Multiple paths can be provided.
    If `skip_missing` is True and path is missing in the data item, will skip the data item.

    Set `make_copy` to True to not modify original data.

    Set `changed_only` to True to keep only the data items that have been changed.

    Usage:
        ```pycon
        >>> asset.remove("d2.l[0]").get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [10]}, 'xyz': 123}]

        >>> asset.remove("xyz", skip_missing=True).get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}}]
        ```
    """
    # Named arguments first, then the caller's extras, in the same order as before
    forwarded = dict(
        path=path,
        skip_missing=skip_missing,
        make_copy=make_copy,
        changed_only=changed_only,
    )
    forwarded.update(kwargs)
    return self.apply("remove", **forwarded)
def rename(
    self: KnowledgeAssetT,
    path: tp.Union[tp.PathRenameDict, tp.MaybeList[tp.PathLikeKey]],
    new_token: tp.Optional[tp.MaybeList[tp.PathKeyToken]] = None,
    skip_missing: tp.Optional[bool] = None,
    make_copy: tp.Optional[bool] = None,
    changed_only: tp.Optional[bool] = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Rename data items or parts of them.

    Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.RenameAssetFunc`;
    all arguments, including `**kwargs`, are forwarded to it.

    Same as `KnowledgeAsset.move` but must specify new token instead of new path.

    Usage:
        ```pycon
        >>> asset.rename("d2.l", "x").get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'x': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'x': [3, 4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'x': [5, 6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'x': [7, 8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'x': [9, 10]}, 'xyz': 123}]

        >>> asset.rename("xyz", "zyx", skip_missing=True, changed_only=True).get()
        [{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'zyx': 123}]
        ```
    """
    # Named arguments first, then the caller's extras, in the same order as before
    forwarded = dict(
        path=path,
        new_token=new_token,
        skip_missing=skip_missing,
        make_copy=make_copy,
        changed_only=changed_only,
    )
    forwarded.update(kwargs)
    return self.apply("rename", **forwarded)
def query(
    self: KnowledgeAssetT,
    expression: tp.CustomTemplateLike,
    query_engine: tp.Optional[str] = None,
    template_context: tp.KwargsLike = None,
    return_type: tp.Optional[str] = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Query using an engine and return the queried data item(s).

    Following engines are supported:

    * "jmespath": Evaluation with `jmespath` package
    * "jsonpath", "jsonpath-ng" or "jsonpath_ng": Evaluation with `jsonpath-ng` package
    * "jsonpath.ext", "jsonpath-ng.ext" or "jsonpath_ng.ext": Evaluation with extended `jsonpath-ng` package
    * None or "template": Evaluation of each data item as a template. The index of the data item
        is represented by "i", the data item itself is represented by "d", the data item under the path
        is represented by "x" while its fields are represented by their names. Uses `KnowledgeAsset.apply`
        on `vectorbtpro.utils.knowledge.base_asset_funcs.QueryAssetFunc`.
    * "pandas": Same as above but variables being columns

    If `return_type` is "item", returns the data item when matched. If `return_type` is "bool",
    returns True when matched.

    Templates can also use the functions defined in `vectorbtpro.utils.search_.search_config`.
    They work on single values and sequences alike.

    Keyword arguments are passed to the respective search/parse/evaluation function, except for
    `wrap` when the engine is not "template": it is consumed here, and if True, the result is
    returned as a new `KnowledgeAsset` instance (a non-list result becomes a single item).

    Raises `ValueError` on an unknown query engine or, for the "pandas" engine, an unknown return type.
    Raises `TypeError` if the expression cannot be turned into a template for the "pandas" engine.

    Usage:
        ```pycon
        >>> asset.query("d['s'] == 'ABC'")
        >>> asset.query("x['s'] == 'ABC'")
        >>> asset.query("s == 'ABC'")
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}]

        >>> asset.query("x['s'] == 'ABC'", return_type="bool")
        [True, False, False, False, False]

        >>> asset.query("find('BC', s)")
        >>> asset.query(lambda s: "BC" in s)
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]

        >>> asset.query("[?contains(s, 'BC')].s", query_engine="jmespath")
        ['ABC', 'BCD']

        >>> asset.query("[].d2.c", query_engine="jmespath")
        ['red', 'blue', 'green', 'yellow', 'black']

        >>> asset.query("[?d2.c != `blue`].d2.l", query_engine="jmespath")
        [[1, 2], [5, 6], [7, 8], [9, 10]]

        >>> asset.query("$[*].d2.c", query_engine="jsonpath.ext")
        ['red', 'blue', 'green', 'yellow', 'black']

        >>> asset.query("$[?(@.b == true)].s", query_engine="jsonpath.ext")
        ['ABC', 'BCD']

        >>> asset.query("s[b]", query_engine="pandas")
        ['ABC', 'BCD']
        ```
    """
    query_engine = self.resolve_setting(query_engine, "query_engine")
    template_context = self.resolve_setting(template_context, "template_context", merge=True)
    return_type = self.resolve_setting(return_type, "return_type")

    engine = None if query_engine is None else query_engine.lower()
    if engine is None or engine == "template":
        # `apply` handles `wrap` and any other keyword arguments itself
        return self.apply(
            "query",
            expression=expression,
            template_context=template_context,
            return_type=return_type,
            **kwargs,
        )

    # The remaining engines forward kwargs to external calls that do not know `wrap`
    # (`filter` always passes it), so take it out and apply it to the result below
    wrap = kwargs.pop("wrap", None)

    if engine == "jmespath":
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("jmespath")
        import jmespath

        new_obj = jmespath.search(expression, self.data, **kwargs)
    elif engine in ("jsonpath", "jsonpath-ng", "jsonpath_ng"):
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("jsonpath_ng")
        import jsonpath_ng

        jsonpath_expr = jsonpath_ng.parse(expression)
        new_obj = [match.value for match in jsonpath_expr.find(self.data, **kwargs)]
    elif engine in ("jsonpath.ext", "jsonpath-ng.ext", "jsonpath_ng.ext"):
        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("jsonpath_ng")
        import jsonpath_ng.ext

        jsonpath_expr = jsonpath_ng.ext.parse(expression)
        new_obj = [match.value for match in jsonpath_expr.find(self.data, **kwargs)]
    elif engine == "pandas":
        if isinstance(expression, str):
            expression = RepEval(expression)
        elif checks.is_function(expression):
            if checks.is_builtin_func(expression):
                # Built-in functions are returned by the template and called on the frame below
                expression = RepFunc(lambda _expression=expression: _expression)
            else:
                expression = RepFunc(expression)
        elif not isinstance(expression, CustomTemplate):
            raise TypeError("Expression must be a template")
        df = pd.DataFrame.from_records(self.data)
        _template_context = flat_merge_dicts(
            {
                "d": df,
                "x": df,
                **df.to_dict(orient="series"),
            },
            template_context,
        )
        result = expression.substitute(_template_context, eval_id="expression", **kwargs)
        if checks.is_function(result):
            result = result(df)
        if return_type.lower() == "item":
            as_filter = True
        elif return_type.lower() == "bool":
            as_filter = False
        else:
            raise ValueError(f"Invalid return type: '{return_type}'")
        if as_filter and isinstance(result, pd.Series) and result.dtype == "bool":
            # A boolean mask selects the matching rows
            result = df[result]
        if isinstance(result, pd.Series):
            new_obj = result.tolist()
        elif isinstance(result, pd.DataFrame):
            new_obj = result.to_dict(orient="records")
        else:
            new_obj = result
    else:
        raise ValueError(f"Invalid query engine: '{query_engine}'")

    if wrap:
        if isinstance(new_obj, list):
            return self.replace(data=new_obj, single_item=False)
        # A scalar/dict result is stored as a one-element list flagged as a single item
        return self.replace(data=[new_obj], single_item=True)
    return new_obj
template_context: tp.KwargsLike = None, return_type: tp.Optional[str] = None, return_path: tp.Optional[bool] = None, merge_matches: tp.Optional[bool] = None, merge_fields: tp.Optional[bool] = None, unique_matches: tp.Optional[bool] = None, unique_fields: tp.Optional[bool] = None, **kwargs, ) -> tp.MaybeKnowledgeAsset: """Find occurrences and return a new `KnowledgeAsset` instance. Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FindAssetFunc`. Uses `vectorbtpro.utils.search_.contains_in_obj` (keyword arguments are passed here) to find any occurrences in each data item if `return_type` is "item" (returns the data item when matched), `return_type` is "field" (returns the field), or `return_type` is "bool" (returns True when matched). For all other return types, uses `vectorbtpro.utils.search_.find_in_obj` and `vectorbtpro.utils.search_.find`. Target can be one or multiple data items. If there are multiple targets and `find_all` is True, the match function will return True only if all targets have been found. Use argument `path` to specify what part of the data item should be searched. For example, "x.y[0].z" to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True and path is missing in the data item, will skip the data item. If `per_path` is True, will consider targets to be provided per path. Use argument `source` instead of `path` or in addition to `path` to also preprocess the source. It can be a string or function (will become a template), or any custom template. In this template, the index of the data item is represented by "i", the data item itself is represented by "d", the data item under the path is represented by "x" while its fields are represented by their names. 
Set `in_dumps` to True to convert the entire data item to string and search in that string. Will use `vectorbtpro.utils.formatting.dump` with `**dump_kwargs`. Disable `merge_matches` and `merge_fields` to keep empty lists when searching for matches and fields respectively. Disable `unique_matches` and `unique_fields` to keep duplicate matches and fields respectively. Usage: ```pycon >>> asset.find("BC").get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}] >>> asset.find("BC", return_type="bool").get() [True, True, False, False, False] >>> asset.find(vbt.Not("BC")).get() [{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}, {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] >>> asset.find("bc", ignore_case=True).get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}] >>> asset.find("bl", path="d2.c").get() [{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] >>> asset.find(5, path="d2.l[0]").get() [{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}] >>> asset.find(True, path="d2.l", source=lambda x: sum(x) >= 10).get() [{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}, {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] >>> asset.find(["A", "B", "C"]).get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}, {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}] >>> asset.find(["A", "B", "C"], find_all=True).get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}] >>> asset.find(r"[ABC]+", mode="regex").get() [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 
'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}, {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}] >>> asset.find("yenlow", mode="fuzzy").get() [{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}] >>> asset.find("yenlow", mode="fuzzy", return_type="match").get() 'yellow' >>> asset.find("yenlow", mode="fuzzy", return_type="match", merge_matches=False).get() [[], [], [], ['yellow'], []] >>> asset.find("yenlow", mode="fuzzy", return_type="match", return_path=True).get() [{}, {}, {}, {('d2', 'c'): ['yellow']}, {}] >>> asset.find("xyz", in_dumps=True).get() [{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] ``` """ found_asset = self.apply( "find", target=target, path=path, per_path=per_path, find_all=find_all, keep_path=keep_path, skip_missing=skip_missing, source=source, in_dumps=in_dumps, dump_kwargs=dump_kwargs, template_context=template_context, return_type=return_type, return_path=return_path, **kwargs, ) return_type = self.resolve_setting(return_type, "return_type") merge_matches = self.resolve_setting(merge_matches, "merge_matches") merge_fields = self.resolve_setting(merge_fields, "merge_fields") unique_matches = self.resolve_setting(unique_matches, "unique_matches") unique_fields = self.resolve_setting(unique_fields, "unique_fields") if ( ((merge_matches and return_type.lower() == "match") or (merge_fields and return_type.lower() == "field")) and isinstance(found_asset, KnowledgeAsset) and len(found_asset) > 0 and isinstance(found_asset[0], list) ): found_asset = found_asset.merge() if ( ((unique_matches and return_type.lower() == "match") or (unique_fields and return_type.lower() == "field")) and isinstance(found_asset, KnowledgeAsset) and len(found_asset) > 0 and isinstance(found_asset[0], str) ): found_asset = found_asset.unique() return found_asset def find_code( self, target: tp.Optional[tp.MaybeIterable[tp.Any]] = None, language: tp.Union[None, bool, tp.MaybeIterable[str]] = None, in_blocks: 
tp.Optional[bool] = None, escape_target: bool = True, escape_language: bool = True, return_type: tp.Optional[str] = "match", flags: int = 0, **kwargs, ) -> tp.MaybeKnowledgeAsset: """Find code using `KnowledgeAsset.find`. For defaults, see `code` in `vectorbtpro._settings.knowledge`.""" language = self.resolve_setting(language, "language", sub_path="code") in_blocks = self.resolve_setting(in_blocks, "in_blocks", sub_path="code") if target is not None: if not isinstance(target, (str, list)): target = list(target) if language is not None: if not isinstance(language, (str, list)): language = list(language) if escape_language: if isinstance(language, list): language = list(map(re.escape, language)) else: language = re.escape(language) if isinstance(language, list): language = rf"(?:{'|'.join(language)})" opt_language = r"[\w+-]+" opt_title = r"(?:\s+[^\n`]+)?" if target is not None: if not isinstance(target, list): targets = [target] single_target = True else: targets = target single_target = False new_target = [] for t in targets: if escape_target: t = re.escape(t) if in_blocks: if language is not None and not isinstance(language, bool): new_t = rf""" ```{language}{opt_title}\n (?:(?!```)[\s\S])*? {t} (?:(?!```)[\s\S])*? ```\s*$ """ elif language is not None and isinstance(language, bool) and language: new_t = rf""" ```{opt_language}{opt_title}\n (?:(?!```)[\s\S])*? {t} (?:(?!```)[\s\S])*? ```\s*$ """ else: new_t = rf""" ```(?:{opt_language}{opt_title})?\n (?:(?!```)[\s\S])*? {t} (?:(?!```)[\s\S])*? ```\s*$ """ else: new_t = rf"(? tp.MaybeKnowledgeAsset: """Find and replace occurrences and return a new `KnowledgeAsset` instance. Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FindReplaceAssetFunc`. Uses `vectorbtpro.utils.search_.find_in_obj` (keyword arguments are passed here) to find occurrences in each data item. Then, uses `vectorbtpro.utils.search_.replace_in_obj` to replace them. 
Target can be one or multiple of data items, either as a list or a dictionary. If there are multiple targets and `find_all` is True, the match function will return True only if all targets have been found. Use argument `path` to specify what part of the data item should be searched. For example, "x.y[0].z" to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True and path is missing in the data item, will skip the data item. If `per_path` is True, will consider targets and replacements to be provided per path. Set `make_copy` to True to not modify original data. Set `changed_only` to True to keep only the data items that have been changed. Usage: ```pycon >>> asset.find_replace("BC", "XY").get() [{'s': 'AXY', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'XYD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}, {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}, {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] >>> asset.find_replace("BC", "XY", changed_only=True).get() [{'s': 'AXY', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'XYD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}] >>> asset.find_replace(r"(D)E(F)", r"\1X\2", mode="regex", changed_only=True).get() [{'s': 'DXF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}] >>> asset.find_replace(True, False, changed_only=True).get() [{'s': 'ABC', 'b': False, 'd2': {'c': 'red', 'l': [1, 2]}}, {'s': 'BCD', 'b': False, 'd2': {'c': 'blue', 'l': [3, 4]}}] >>> asset.find_replace(3, 30, path="d2.l", changed_only=True).get() [{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [30, 4]}}] >>> asset.find_replace({1: 10, 4: 40}, path="d2.l", changed_only=True).get() >>> asset.find_replace({1: 10, 4: 
def find_remove(
    self: KnowledgeAssetT,
    target: tp.Union[dict, tp.MaybeList[tp.Any]],
    path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
    per_path: tp.Optional[bool] = None,
    find_all: tp.Optional[bool] = None,
    keep_path: tp.Optional[bool] = None,
    skip_missing: tp.Optional[bool] = None,
    make_copy: tp.Optional[bool] = None,
    changed_only: tp.Optional[bool] = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Find occurrences, remove them, and return a new `KnowledgeAsset` instance.

    Runs `KnowledgeAsset.apply` on
    `vectorbtpro.utils.knowledge.base_asset_funcs.FindRemoveAssetFunc`.

    The arguments mirror those of `KnowledgeAsset.find_replace`, minus the replacement.
    All of them, together with `**kwargs`, are forwarded to the asset function unchanged."""
    # Gather the named options first, then hand everything over in a single call
    forwarded = dict(
        target=target,
        path=path,
        per_path=per_path,
        find_all=find_all,
        keep_path=keep_path,
        skip_missing=skip_missing,
        make_copy=make_copy,
        changed_only=changed_only,
    )
    return self.apply("find_remove", **forwarded, **kwargs)
def unflatten(
    self: KnowledgeAssetT,
    path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
    skip_missing: tp.Optional[bool] = None,
    make_copy: tp.Optional[bool] = None,
    changed_only: tp.Optional[bool] = None,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Unflatten data items or parts of them.

    Runs `KnowledgeAsset.apply` on
    `vectorbtpro.utils.knowledge.base_asset_funcs.UnflattenAssetFunc`.

    Argument `path` selects the part of the data item to operate on, for example
    "x.y[0].z" to navigate nested dictionaries/lists. More than one path can be given.
    If `skip_missing` is True, data items that don't have the path are skipped.

    Set `make_copy` to True to leave the original data untouched. Set `changed_only`
    to True to keep only the data items that have been changed.

    Keyword arguments are passed to `vectorbtpro.utils.search_.unflatten_obj`.

    Usage:
        ```pycon
        >>> asset.flatten().unflatten().get()
        [{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
         {'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
         {'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
         {'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
         {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
        ```
    """
    # Gather the named options first, then hand everything over in a single call
    forwarded = dict(
        path=path,
        skip_missing=skip_missing,
        make_copy=make_copy,
        changed_only=changed_only,
    )
    return self.apply("unflatten", **forwarded, **kwargs)
def dump_all(
    self,
    source: tp.Optional[tp.CustomTemplateLike] = None,
    dump_engine: tp.Optional[str] = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> str:
    """Dump the entire data list as a single data item.

    In contrast to `KnowledgeAsset.dump`, the data list is handed over as a whole to
    `vectorbtpro.utils.knowledge.base_asset_funcs.DumpAssetFunc.prepare_and_call`.

    See `KnowledgeAsset.dump` for arguments."""
    from vectorbtpro.utils.knowledge.base_asset_funcs import DumpAssetFunc as dump_func

    dump_options = dict(
        source=source,
        dump_engine=dump_engine,
        template_context=template_context,
    )
    return dump_func.prepare_and_call(self.data, **dump_options, **kwargs)
@classmethod
def get_keys_and_groups(
    cls,
    by: tp.List[tp.Any],
    uniform_groups: bool = False,
) -> tp.Tuple[tp.List[tp.Any], tp.List[tp.List[int]]]:
    """Get group keys and the positions of the items that belong to each group.

    Two items fall into the same group if they are the same object or compare equal.
    No hashing is involved, so unhashable values such as lists and dicts can be grouped too.

    If `uniform_groups` is True, an item is only compared against the key of the most recent
    group, such that only adjacent equal items are grouped together and the same key may
    appear more than once. Otherwise, an item is compared against the keys of all groups
    found so far, and each key appears exactly once.

    Argument `by` can be any iterable; it's traversed exactly once.

    Returns a tuple of keys (in the order of their first appearance) and groups, where
    each group is a list of positions in `by`."""
    keys = []
    groups = []
    if uniform_groups:
        for i, item in enumerate(by):
            if keys and (keys[-1] is item or keys[-1] == item):
                groups[-1].append(i)
            else:
                keys.append(item)
                groups.append([i])
        return keys, groups

    for i, item in enumerate(by):
        # Identity is checked first so that objects that don't equal themselves
        # (such as NaN) still end up in their own group
        for key, group in zip(keys, groups):
            if item is key or item == key:
                group.append(i)
                break
        else:
            # No existing group matched: the item opens a new group
            keys.append(item)
            groups.append([i])
    return keys, groups
None, show_progress: tp.Optional[bool] = None, pbar_kwargs: tp.KwargsLike = None, wrap: tp.Optional[bool] = None, return_iterator: bool = False, **kwargs, ) -> tp.MaybeKnowledgeAsset: """Reduce data items. Function can be a callable, a tuple of function and its arguments, a `vectorbtpro.utils.execution.Task` instance, a subclass of `vectorbtpro.utils.knowledge.base_asset_funcs.AssetFunc` or its prefix or full name. It can also be an expression or a template. In this template, the index of the data item is represented by "i", the data items themselves are represented by "d1" and "d2" or "x1" and "x2". If an initializer is provided, the first set of values will be `d1=initializer` and `d2=self.data[0]`. If not, it will be `d1=self.data[0]` and `d2=self.data[1]`. If `by` is provided, see `KnowledgeAsset.groupby_reduce`. If `wrap` is True, returns a new `KnowledgeAsset` instance, otherwise raw output. Usage: ```pycon >>> asset.reduce(lambda d1, d2: vbt.merge_dicts(d1, d2)) >>> asset.reduce(vbt.merge_dicts) >>> asset.reduce("{**d1, **d2}") {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123} >>> asset.reduce("{**d1, **d2}", by="b") [{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}, {'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}] ``` """ if by is not None: return self.groupby_reduce( func, *args, by=by, initializer=initializer, template_context=template_context, show_progress=show_progress, pbar_kwargs=pbar_kwargs, wrap=wrap, **kwargs, ) show_progress = self.resolve_setting(show_progress, "show_progress") pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True) asset_func_meta = {} if isinstance(func, str) and not func.isidentifier(): func = RepEval(func) elif not isinstance(func, CustomTemplate): from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline func, args, kwargs = AssetPipeline.resolve_task( func, *args, cond_kwargs=dict(asset_cls=type(self)), asset_func_meta=asset_func_meta, 
def groupby_reduce(
    self: KnowledgeAssetT,
    func: tp.CustomTemplateLike,
    *args,
    by: tp.Optional[tp.PathLikeKey] = None,
    uniform_groups: tp.Optional[bool] = None,
    get_kwargs: tp.KwargsLike = None,
    execute_kwargs: tp.KwargsLike = None,
    return_group_keys: bool = False,
    **kwargs,
) -> tp.Union[KnowledgeAssetT, dict, list]:
    """Group data items by keys and reduce.

    Uses `by` as `path` in `KnowledgeAsset.get` (together with `**get_kwargs`), groups the
    returned values with `KnowledgeAsset.get_keys_and_groups`, and runs `KnowledgeAsset.reduce`
    on each group. Set `uniform_groups` to True to only group unique values that are
    located adjacent to each other.

    Variable arguments `func`, `*args`, and `**kwargs` are passed to each call of
    `KnowledgeAsset.reduce`. The calls are wrapped into tasks and run by `execute`
    with `**execute_kwargs`.

    If `return_group_keys` is True, returns a dict with group keys mapped to the results.
    Otherwise, if the first result is an instance of this class, combines all results using
    `combine` of this class. Otherwise, returns the list of results.

    Raises a `ValueError` if there are no groups."""
    uniform_groups = self.resolve_setting(uniform_groups, "uniform_groups")
    execute_kwargs = self.resolve_setting(execute_kwargs, "execute_kwargs", merge=True)

    if get_kwargs is None:
        get_kwargs = {}
    # From here on, `by` holds the values found under the path rather than the path itself
    by = self.get(path=by, **get_kwargs)
    keys, groups = self.get_keys_and_groups(by, uniform_groups=uniform_groups)
    if len(groups) == 0:
        raise ValueError("Groups are empty")
    # One task per group: each reduces a sub-asset that contains only the group's items
    tasks = []
    for i, group in enumerate(groups):
        group_instance = self.get_items(group)
        tasks.append(Task(group_instance.reduce, func, *args, **kwargs))
    # NOTE(review): `get_caller_qualname` presumably inspects the call stack, so these
    # calls should stay directly in this method's body - confirm before moving them
    prefix = get_caller_qualname().split(".")[-1]
    # Defaults come first such that user-provided `execute_kwargs` take precedence;
    # progress is explicitly disabled when there is only a single group
    execute_kwargs = merge_dicts(
        dict(
            show_progress=False if len(groups) == 1 else None,
            pbar_kwargs=dict(
                bar_id=get_caller_qualname(),
                prefix=prefix,
            ),
        ),
        execute_kwargs,
    )
    results = execute(tasks, size=len(groups), **execute_kwargs)
    if return_group_keys:
        # Keys become dict keys here - presumably they must be hashable in this mode
        return dict(zip(keys, results))
    if len(results) > 0 and isinstance(results[0], type(self)):
        return type(self).combine(results)
    return results
@classmethod
def describe_lengths(cls, lengths: list, **describe_kwargs) -> dict:
    """Describe values representing lengths.

    Builds a Pandas Series out of `lengths` and describes it using `pd.Series.describe`
    with `**describe_kwargs`. Drops the count and the standard deviation, prefixes the
    remaining statistics with "len_", and casts each of them to an integer (by truncation),
    except the mean, which is kept as-is.

    Returns an empty dict if `lengths` is empty since there is nothing to describe."""
    if len(lengths) == 0:
        # An empty Series is described as an object Series, which has no numeric
        # statistics (such as "std") and only NaN values that cannot be cast to int
        return {}
    stats = pd.Series(lengths).describe(**describe_kwargs).to_dict()
    # Tolerate missing keys rather than assuming a fixed layout of the description
    stats.pop("count", None)
    stats.pop("std", None)
    return {"len_" + k: v if k == "mean" else int(v) for k, v in stats.items()}
int(describe_dict["count"]) if "unique" in describe_dict: describe_dict["unique"] = int(describe_dict["unique"]) if pd.api.types.is_integer_dtype(describe_sr.dtype): new_describe_dict = {} for _k, _v in describe_dict.items(): if _k not in {"mean", "std"}: new_describe_dict[_k] = int(_v) else: new_describe_dict[_k] = _v describe_dict = new_describe_dict if "unique" in describe_dict and describe_dict["unique"] == describe_dict["count"]: del describe_dict["top"] del describe_dict["freq"] if "unique" in describe_dict and describe_dict["count"] == 1: del describe_dict["unique"] description[k].update(describe_dict) if len(valid_types) == 1 and checks.is_collection(valid_x): lengths = [len(_v) for _v in v if _v is not None] description[k].update(self.describe_lengths(lengths, **describe_kwargs)) if wrap: return self.replace(data=[description], single_item=True) return description def print_schema(self, **kwargs) -> None: """Print schema. Keyword arguments are split between `KnowledgeAsset.describe` and `vectorbtpro.utils.path_.dir_tree_from_paths`. Usage: ```pycon >>> asset.print_schema() / ├── s [5/5, str] ├── b [2/5, bool] ├── d2 [5/5, dict] │ ├── c [5/5, str] │ └── l │ ├── 0 [5/5, int] │ └── 1 [5/5, int] └── xyz [1/5, int] 2 directories, 6 files ``` """ dir_tree_arg_names = set(get_func_arg_names(dir_tree_from_paths)) dir_tree_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in dir_tree_arg_names} orig_describe_dict = self.describe(wrap=False, **kwargs) flat_describe_dict = self.flatten( skip_missing=True, make_copy=True, changed_only=False, ).describe(wrap=False, **kwargs) describe_dict = flat_merge_dicts(orig_describe_dict, flat_describe_dict) paths = [] path_names = [] for k, v in describe_dict.items(): if k is None: k = "." 
def join(self, separator: tp.Optional[str] = None) -> str:
    """Join the list of string data items.

    Returns an empty string if there are no data items, and the only data item
    as-is if there is just one.

    If `separator` is None, it's derived from how the data items end:

    * Every item ends with a newline, tab, or space: empty string
    * Every item ends with "}" or "]": ", "
    * Otherwise: two newlines

    If the joined string starts with "{" and ends with "}", it's additionally
    enclosed in square brackets."""
    items = self.data
    num_items = len(items)
    if num_items == 0:
        return ""
    if num_items == 1:
        return items[0]
    if separator is None:
        if all(item.endswith(("\n", "\t", " ")) for item in items):
            separator = ""
        elif all(item.endswith(("}", "]")) for item in items):
            separator = ", "
        else:
            separator = "\n\n"
    joined = separator.join(items)
    is_brace_delimited = joined.startswith("{") and joined.endswith("}")
    return "[" + joined + "]" if is_brace_delimited else joined
def rank(
    self,
    query: str,
    to_documents_kwargs: tp.KwargsLike = None,
    wrap_documents: tp.Optional[bool] = None,
    cache_documents: bool = False,
    cache_key: tp.Optional[str] = None,
    asset_cache_manager: tp.Optional[tp.MaybeType[AssetCacheManager]] = None,
    asset_cache_manager_kwargs: tp.KwargsLike = None,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.MaybeKnowledgeAsset:
    """Rank documents by their similarity to a query.

    First, converts to `vectorbtpro.utils.knowledge.chatting.TextDocument` format using
    `KnowledgeAsset.to_documents` and `**to_documents_kwargs`, unless the first data item is
    already an instance of `vectorbtpro.utils.knowledge.chatting.StoreDocument` (or there is
    no data), in which case the data is used as-is. Then, uses
    `vectorbtpro.utils.knowledge.chatting.rank_documents` with `**kwargs` for actual ranking.

    If `cache_documents` is True and `cache_key` is not None, will use an asset cache manager
    to store the generated text documents in a local and/or disk cache after conversion.
    Running the same method again will use the cached documents. The manager can be provided
    via `asset_cache_manager` as a subclass or an instance of `AssetCacheManager`;
    `**asset_cache_manager_kwargs` are used to either initialize the class or replace the instance.
    Unless `silence_warnings` is True, a warning is issued when nothing was loaded from the cache.

    If `wrap_documents` is None, it becomes True only if the data has been used as-is.
    If it's False, documents in the ranked output are replaced by their data.

    Returns a new instance with the ranked documents as data."""
    from vectorbtpro.utils.knowledge.chatting import StoreDocument, ScoredDocument, rank_documents

    if cache_documents:
        # Resolve the cache manager into an instance
        if asset_cache_manager is None:
            asset_cache_manager = AssetCacheManager
        if asset_cache_manager_kwargs is None:
            asset_cache_manager_kwargs = {}
        if isinstance(asset_cache_manager, type):
            checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
            asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs)
        else:
            checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
            if asset_cache_manager_kwargs:
                asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs)
    documents = None
    if cache_documents and cache_key is not None:
        # A result of None is treated as a cache miss
        documents = asset_cache_manager.load_asset(cache_key)
        if documents is not None:
            if wrap_documents is None:
                wrap_documents = False
        else:
            if not silence_warnings:
                warn("Caching documents...")
    if documents is None:
        # Nothing came from the cache: convert the data or use it as-is
        if self.data and not isinstance(self.data[0], StoreDocument):
            if to_documents_kwargs is None:
                to_documents_kwargs = {}
            documents = self.to_documents(**to_documents_kwargs)
            # Only a wrapped result (that is, a `KnowledgeAsset` instance) gets cached
            if cache_documents and cache_key is not None and isinstance(documents, KnowledgeAsset):
                asset_cache_manager.save_asset(documents, cache_key)
            if wrap_documents is None:
                wrap_documents = False
        else:
            # Data already consists of documents (or is empty): keep the output wrapped
            documents = self.data
            if wrap_documents is None:
                wrap_documents = True
    ranked_documents = rank_documents(query=query, documents=documents, **kwargs)
    if not wrap_documents:

        def _unwrap(document):
            # Scored documents keep their wrapper but get their inner document and
            # child documents unwrapped recursively
            if isinstance(document, ScoredDocument):
                return document.replace(
                    document=_unwrap(document.document),
                    child_documents=[_unwrap(d) for d in document.child_documents],
                )
            # Store documents are replaced by their data
            if isinstance(document, StoreDocument):
                return document.data
            return document

        ranked_documents = list(map(_unwrap, ranked_documents))
    return self.replace(data=ranked_documents)
def print(self, *args, **kwargs) -> None:
    """Print the context of this asset.

    The context is built by `KnowledgeAsset.to_context`, which receives
    all positional and keyword arguments."""
    context = self.to_context(*args, **kwargs)
    print(context)
See `vectorbtpro.utils.knowledge` for the toy dataset.""" import hashlib import inspect import re import sys from pathlib import Path from collections.abc import MutableMapping import numpy as np from vectorbtpro import _typing as tp from vectorbtpro.utils import checks from vectorbtpro.utils.attr_ import DefineMixin, define from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts, Configured, HasSettings, ExtSettingsPath from vectorbtpro.utils.decorators import memoized_method, hybrid_method from vectorbtpro.utils.knowledge.formatting import ContentFormatter, HTMLFileFormatter, resolve_formatter from vectorbtpro.utils.parsing import get_func_arg_names, get_func_kwargs, get_forward_args from vectorbtpro.utils.template import CustomTemplate, SafeSub, RepFunc from vectorbtpro.utils.warnings_ import warn try: if not tp.TYPE_CHECKING: raise ImportError from tiktoken import Encoding as EncodingT except ImportError: EncodingT = "Encoding" try: if not tp.TYPE_CHECKING: raise ImportError from openai import OpenAI as OpenAIT, Stream as StreamT from openai.types.chat.chat_completion import ChatCompletion as ChatCompletionT from openai.types.chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunkT except ImportError: OpenAIT = "OpenAI" StreamT = "Stream" ChatCompletionT = "ChatCompletion" ChatCompletionChunkT = "ChatCompletionChunk" try: if not tp.TYPE_CHECKING: raise ImportError from litellm import ModelResponse as ModelResponseT, CustomStreamWrapper as CustomStreamWrapperT except ImportError: ModelResponseT = "ModelResponse" CustomStreamWrapperT = "CustomStreamWrapper" try: if not tp.TYPE_CHECKING: raise ImportError from llama_index.core.embeddings import BaseEmbedding as BaseEmbeddingT from llama_index.core.llms import LLM as LLMT, ChatMessage as ChatMessageT, ChatResponse as ChatResponseT from llama_index.core.node_parser import NodeParser as NodeParserT except ImportError: BaseEmbeddingT = "BaseEmbedding" LLMT = "LLM" ChatMessageT = 
class Tokenizer(Configured):
    """Abstract class for tokenizers.

    Subclasses are expected to override `Tokenizer.encode` and `Tokenizer.decode`,
    and optionally `Tokenizer.count_tokens_in_messages`; here, each of them raises
    a `NotImplementedError`.

    For defaults, see `knowledge.chat.tokenizer_config` in `vectorbtpro._settings.knowledge`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = None
    """Short name of the class."""

    _settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.tokenizer_config"]

    def __init__(self, template_context: tp.KwargsLike = None, **kwargs) -> None:
        # The raw argument goes to the parent, the resolved one is stored on the instance
        Configured.__init__(self, template_context=template_context, **kwargs)

        template_context = self.resolve_setting(template_context, "template_context", merge=True)

        self._template_context = template_context

    @property
    def template_context(self) -> tp.Kwargs:
        """Context used to substitute templates."""
        return self._template_context

    def encode(self, text: str) -> tp.Tokens:
        """Encode text into a list of tokens.

        Must be implemented by a subclass."""
        raise NotImplementedError

    def decode(self, tokens: tp.Tokens) -> str:
        """Decode a list of tokens into text.

        Must be implemented by a subclass."""
        raise NotImplementedError

    @memoized_method
    def encode_single(self, text: str) -> tp.Token:
        """Encode text into a single token.

        Raises a `ValueError` if the text doesn't encode into exactly one token."""
        tokens = self.encode(text)
        if len(tokens) == 0:
            # Without this check, indexing below would fail with a bare IndexError
            raise ValueError("Text contains no tokens")
        if len(tokens) > 1:
            raise ValueError("Text contains multiple tokens")
        return tokens[0]

    @memoized_method
    def decode_single(self, token: tp.Token) -> str:
        """Decode a single token into text."""
        return self.decode([token])

    def count_tokens(self, text: str) -> int:
        """Count tokens in a text.

        The count is the length of the output of `Tokenizer.encode`."""
        return len(self.encode(text))

    def count_tokens_in_messages(self, messages: tp.ChatMessages) -> int:
        """Count tokens in messages.

        Must be implemented by a subclass."""
        raise NotImplementedError
For defaults, see `chat.tokenizer_configs.tiktoken` in `vectorbtpro._settings.knowledge`.""" _short_name = "tiktoken" _settings_path: tp.SettingsPath = "knowledge.chat.tokenizer_configs.tiktoken" def __init__( self, encoding: tp.Union[None, str, EncodingT] = None, model: tp.Optional[str] = None, tokens_per_message: tp.Optional[int] = None, tokens_per_name: tp.Optional[int] = None, **kwargs, ) -> None: Tokenizer.__init__( self, encoding=encoding, model=model, tokens_per_message=tokens_per_message, tokens_per_name=tokens_per_name, **kwargs, ) from vectorbtpro.utils.module_ import assert_can_import assert_can_import("tiktoken") from tiktoken import Encoding, get_encoding, encoding_for_model encoding = self.resolve_setting(encoding, "encoding") model = self.resolve_setting(model, "model") tokens_per_message = self.resolve_setting(tokens_per_message, "tokens_per_message") tokens_per_name = self.resolve_setting(tokens_per_name, "tokens_per_name") if isinstance(encoding, str): if encoding.startswith("model_or_"): try: if model is None: raise KeyError encoding = encoding_for_model(model) except KeyError: encoding = encoding[len("model_or_") :] encoding = get_encoding(encoding) if "k_base" in encoding else encoding_for_model(encoding) elif isinstance(encoding, str): encoding = get_encoding(encoding) if "k_base" in encoding else encoding_for_model(encoding) checks.assert_instance_of(encoding, Encoding, arg_name="encoding") self._encoding = encoding self._tokens_per_message = tokens_per_message self._tokens_per_name = tokens_per_name @property def encoding(self) -> EncodingT: """Encoding.""" return self._encoding @property def tokens_per_message(self) -> int: """Tokens per message.""" return self._tokens_per_message @property def tokens_per_name(self) -> int: """Tokens per name.""" return self._tokens_per_name def encode(self, text: str) -> tp.Tokens: return self.encoding.encode(text) def decode(self, tokens: tp.Tokens) -> str: return self.encoding.decode(tokens) def 
count_tokens_in_messages(self, messages: tp.ChatMessages) -> int: num_tokens = 0 for message in messages: num_tokens += self.tokens_per_message for key, value in message.items(): num_tokens += self.count_tokens(value) if key == "name": num_tokens += self.tokens_per_name num_tokens += 3 return num_tokens def resolve_tokenizer(tokenizer: tp.TokenizerLike = None) -> tp.MaybeType[Tokenizer]: """Resolve a subclass or an instance of `Tokenizer`. The following values are supported: * "tiktoken" (`TikTokenizer`) * A subclass or an instance of `Tokenizer` """ if tokenizer is None: from vectorbtpro._settings import settings chat_cfg = settings["knowledge"]["chat"] tokenizer = chat_cfg["tokenizer"] if isinstance(tokenizer, str): curr_module = sys.modules[__name__] found_tokenizer = None for name, cls in inspect.getmembers(curr_module, inspect.isclass): if name.endswith("Tokenizer"): _short_name = getattr(cls, "_short_name", None) if _short_name is not None and _short_name.lower() == tokenizer.lower(): found_tokenizer = cls break if found_tokenizer is None: raise ValueError(f"Invalid tokenizer: '{tokenizer}'") tokenizer = found_tokenizer if isinstance(tokenizer, type): checks.assert_subclass_of(tokenizer, Tokenizer, arg_name="tokenizer") else: checks.assert_instance_of(tokenizer, Tokenizer, arg_name="tokenizer") return tokenizer def tokenize(text: str, tokenizer: tp.TokenizerLike = None, **kwargs) -> tp.Tokens: """Tokenize text. Resolves `tokenizer` with `resolve_tokenizer`. Keyword arguments are passed to either initialize a class or replace an instance of `Tokenizer`.""" tokenizer = resolve_tokenizer(tokenizer=tokenizer) if isinstance(tokenizer, type): tokenizer = tokenizer(**kwargs) elif kwargs: tokenizer = tokenizer.replace(**kwargs) return tokenizer.encode(text) def detokenize(tokens: tp.Tokens, tokenizer: tp.TokenizerLike = None, **kwargs) -> str: """Detokenize text. Resolves `tokenizer` with `resolve_tokenizer`. 
class Embeddings(Configured):
    """Abstract class for embedding providers.

    Subclasses must implement `Embeddings.get_embedding` and may override
    `Embeddings.get_embedding_batch` when the provider supports batched requests.

    For defaults, see `knowledge.chat.embeddings_config` in `vectorbtpro._settings.knowledge`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = None
    """Short name of the class."""

    _expected_keys_mode: tp.ExpectedKeysMode = "disable"

    _settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.embeddings_config"]

    def __init__(
        self,
        batch_size: tp.Optional[int] = None,
        show_progress: tp.Optional[bool] = None,
        pbar_kwargs: tp.KwargsLike = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        Configured.__init__(
            self,
            batch_size=batch_size,
            show_progress=show_progress,
            pbar_kwargs=pbar_kwargs,
            template_context=template_context,
            **kwargs,
        )

        # Arguments left as None are filled from the settings
        batch_size = self.resolve_setting(batch_size, "batch_size")
        show_progress = self.resolve_setting(show_progress, "show_progress")
        pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
        template_context = self.resolve_setting(template_context, "template_context", merge=True)

        self._batch_size = batch_size
        self._show_progress = show_progress
        self._pbar_kwargs = pbar_kwargs
        self._template_context = template_context

    @property
    def batch_size(self) -> tp.Optional[int]:
        """Batch size.

        Set to None to disable batching."""
        return self._batch_size

    @property
    def show_progress(self) -> tp.Optional[bool]:
        """Whether to show progress bar."""
        return self._show_progress

    @property
    def pbar_kwargs(self) -> tp.Kwargs:
        """Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`."""
        return self._pbar_kwargs

    @property
    def template_context(self) -> tp.Kwargs:
        """Context used to substitute templates."""
        return self._template_context

    @property
    def model(self) -> tp.Optional[str]:
        """Model."""
        return None

    def get_embedding(self, query: str) -> tp.List[float]:
        """Get embedding for a query."""
        raise NotImplementedError

    def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
        """Get embeddings for one batch of queries.

        The default implementation calls `Embeddings.get_embedding` once per query."""
        return [self.get_embedding(query) for query in batch]

    def iter_embedding_batches(self, queries: tp.List[str]) -> tp.Iterator[tp.List[tp.List[float]]]:
        """Get iterator of embedding batches.

        Queries are split into chunks of `Embeddings.batch_size` (or kept as a single batch if
        batching is disabled). An empty `queries` list yields no batches, so the provider
        is never called with an empty input."""
        from vectorbtpro.utils.pbar import ProgressBar

        if len(queries) == 0:
            # Without this guard, disabled batching would dispatch one empty batch
            batches = []
        elif self.batch_size is not None:
            batches = [queries[i : i + self.batch_size] for i in range(0, len(queries), self.batch_size)]
        else:
            batches = [queries]
        pbar_kwargs = merge_dicts(dict(prefix="get_embeddings"), self.pbar_kwargs)
        with ProgressBar(total=len(queries), show_progress=self.show_progress, **pbar_kwargs) as pbar:
            for batch in batches:
                yield self.get_embedding_batch(batch)
                # Progress is tracked per query, not per batch
                pbar.update(len(batch))

    def get_embeddings(self, queries: tp.List[str]) -> tp.List[tp.List[float]]:
        """Get embeddings for multiple queries.

        Returns a flat list with one embedding per query, in the order of `queries`."""
        return [embedding for batch in self.iter_embedding_batches(queries) for embedding in batch]
def __init__(
    self,
    model: tp.Optional[str] = None,
    batch_size: tp.Optional[int] = None,
    show_progress: tp.Optional[bool] = None,
    pbar_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> None:
    Embeddings.__init__(
        self,
        model=model,
        batch_size=batch_size,
        show_progress=show_progress,
        pbar_kwargs=pbar_kwargs,
        template_context=template_context,
        **kwargs,
    )

    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("openai")
    from openai import OpenAI

    # Provider-specific settings, overridden by explicitly passed keyword arguments
    config = merge_dicts(self.get_settings(inherit=False), kwargs)
    fallback_model = config.pop("model", None)
    if model is None:
        model = fallback_model
    if model is None:
        raise ValueError("Must provide a model")

    # Drop everything that is an argument of this constructor rather than of OpenAI
    own_arg_names = get_func_kwargs(type(self).__init__)
    config = {k: v for k, v in config.items() if k not in own_arg_names}

    # Keys accepted by the client constructor go to the client, the rest to the embeddings call
    client_arg_names = set(get_func_arg_names(OpenAI.__init__))
    client_kwargs = {k: v for k, v in config.items() if k in client_arg_names}
    embeddings_kwargs = {k: v for k, v in config.items() if k not in client_arg_names}

    self._model = model
    self._client = OpenAI(**client_kwargs)
    self._embeddings_kwargs = embeddings_kwargs
class LiteLLMEmbeddings(Embeddings):
    """Embeddings class for LiteLLM.

    Keyword arguments that are not arguments of the constructor are forwarded to `litellm.embedding`.

    For defaults, see `chat.embeddings_configs.litellm` in `vectorbtpro._settings.knowledge`."""

    _short_name = "litellm"

    _settings_path: tp.SettingsPath = "knowledge.chat.embeddings_configs.litellm"

    def __init__(
        self,
        model: tp.Optional[str] = None,
        batch_size: tp.Optional[int] = None,
        show_progress: tp.Optional[bool] = None,
        pbar_kwargs: tp.KwargsLike = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        Embeddings.__init__(
            self,
            model=model,
            batch_size=batch_size,
            show_progress=show_progress,
            pbar_kwargs=pbar_kwargs,
            template_context=template_context,
            **kwargs,
        )

        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("litellm")

        # Provider-specific settings, overridden by explicitly passed keyword arguments
        config = merge_dicts(self.get_settings(inherit=False), kwargs)
        fallback_model = config.pop("model", None)
        if model is None:
            model = fallback_model
        if model is None:
            raise ValueError("Must provide a model")

        # Constructor-level arguments must not reach the embedding call
        own_arg_names = get_func_kwargs(type(self).__init__)
        for key in [k for k in config if k in own_arg_names]:
            config.pop(key)

        self._model = model
        self._embedding_kwargs = config

    @property
    def model(self) -> str:
        return self._model

    @property
    def embedding_kwargs(self) -> tp.Kwargs:
        """Keyword arguments passed to `litellm.embedding`."""
        return self._embedding_kwargs

    def get_embedding(self, query: str) -> tp.List[float]:
        from litellm import embedding

        response = embedding(self.model, input=query, **self.embedding_kwargs)
        first_item = response.data[0]
        return first_item["embedding"]

    def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
        from litellm import embedding

        response = embedding(self.model, input=batch, **self.embedding_kwargs)
        vectors = []
        for item in response.data:
            vectors.append(item["embedding"])
        return vectors
def __init__(
    self,
    embedding: tp.Union[None, str, tp.MaybeType[BaseEmbeddingT]] = None,
    batch_size: tp.Optional[int] = None,
    show_progress: tp.Optional[bool] = None,
    pbar_kwargs: tp.KwargsLike = None,
    template_context: tp.KwargsLike = None,
    **kwargs,
) -> None:
    Embeddings.__init__(
        self,
        embedding=embedding,
        batch_size=batch_size,
        show_progress=show_progress,
        pbar_kwargs=pbar_kwargs,
        template_context=template_context,
        **kwargs,
    )

    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("llama_index")
    from llama_index.core.embeddings import BaseEmbedding

    # Provider-specific settings, overridden by explicitly passed keyword arguments
    config = merge_dicts(self.get_settings(inherit=False), kwargs)
    fallback_embedding = config.pop("embedding", None)
    if embedding is None:
        embedding = fallback_embedding
    if embedding is None:
        raise ValueError("Must provide an embedding name or path")

    # Constructor-level arguments must not reach the embedding class
    own_arg_names = get_func_kwargs(type(self).__init__)
    for key in [k for k in config if k in own_arg_names]:
        config.pop(key)

    if isinstance(embedding, str):
        import llama_index.embeddings
        from vectorbtpro.utils.module_ import search_package

        wanted = embedding

        def _match_func(k, v):
            # Only subclasses of BaseEmbedding are candidates
            if not (isinstance(v, type) and issubclass(v, BaseEmbedding)):
                return False
            # A dotted name is matched as a (case-sensitive) path suffix
            if "." in wanted:
                return k.endswith(wanted)
            # Otherwise compare the class name, with and without the "Embedding" part
            leaf = k.split(".")[-1]
            if leaf.lower() == wanted.lower():
                return True
            return leaf.replace("Embedding", "").lower() == wanted.lower().replace("_", "")

        found_embedding = search_package(
            llama_index.embeddings,
            _match_func,
            path_attrs=True,
            return_first=True,
        )
        if found_embedding is None:
            raise ValueError(f"Embedding '{wanted}' not found")
        embedding = found_embedding

    is_class = isinstance(embedding, type)
    if is_class:
        checks.assert_subclass_of(embedding, BaseEmbedding, arg_name="embedding")
        embedding_cls = embedding
    else:
        checks.assert_instance_of(embedding, BaseEmbedding, arg_name="embedding")
        embedding_cls = type(embedding)
    embedding_name = embedding_cls.__name__.replace("Embedding", "").lower()
    module_name = embedding_cls.__module__

    # Per-embedding overrides are looked up by short class name first, then by module name
    embedding_configs = config.pop("embedding_configs", {})
    override_key = None
    if embedding_name in embedding_configs:
        override_key = embedding_name
    elif module_name in embedding_configs:
        override_key = module_name
    if override_key is not None:
        config = merge_dicts(config, embedding_configs[override_key])

    if is_class:
        embedding = embedding(**config)
    elif len(kwargs) > 0:
        raise ValueError("Cannot apply config to already initialized embedding")

    # The model name comes from the config or, failing that, from the constructor's default
    model_name = config.get("model_name", None)
    if model_name is None:
        model_name = get_func_kwargs(type(embedding).__init__).get("model_name", None)

    self._model = model_name
    self._embedding = embedding
def resolve_embeddings(embeddings: tp.EmbeddingsLike = None) -> tp.MaybeType[Embeddings]:
    """Resolve a subclass or an instance of `Embeddings`.

    If `embeddings` is None, it is read from `knowledge.chat.embeddings` in the settings.

    The following values are supported:

    * "openai" (`OpenAIEmbeddings`)
    * "litellm" (`LiteLLMEmbeddings`)
    * "llama_index" (`LlamaIndexEmbeddings`)
    * "auto": Any installed from above, in the same order
    * A subclass or an instance of `Embeddings`
    """
    if embeddings is None:
        from vectorbtpro._settings import settings

        embeddings = settings["knowledge"]["chat"]["embeddings"]

    if isinstance(embeddings, str):
        if embeddings.lower() == "auto":
            from vectorbtpro.utils.module_ import check_installed

            # Take the first installed package, in order of preference
            for package in ("openai", "litellm", "llama_index"):
                if check_installed(package):
                    embeddings = package
                    break
            else:
                raise ValueError("No packages for embeddings installed")

        def _is_match(name, cls):
            # Candidates are classes of this module named "*Embeddings" with a matching short name
            if not name.endswith("Embeddings"):
                return False
            short_name = getattr(cls, "_short_name", None)
            return short_name is not None and short_name.lower() == embeddings.lower()

        members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
        found_embeddings = next((cls for name, cls in members if _is_match(name, cls)), None)
        if found_embeddings is None:
            raise ValueError(f"Invalid embeddings: '{embeddings}'")
        embeddings = found_embeddings

    if isinstance(embeddings, type):
        checks.assert_subclass_of(embeddings, Embeddings, arg_name="embeddings")
    else:
        checks.assert_instance_of(embeddings, Embeddings, arg_name="embeddings")
    return embeddings
class Completions(Configured):
    """Abstract class for completion providers.

    Subclasses implement the four provider hooks (`Completions.get_chat_response`,
    `Completions.get_message_content`, `Completions.get_stream_response`, and
    `Completions.get_delta_content`); message preparation and output formatting live here.

    For argument descriptions, see their properties, like `Completions.chat_history`.

    For defaults, see `knowledge.chat.completions_config` in `vectorbtpro._settings.knowledge`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = None
    """Short name of the class."""

    _expected_keys_mode: tp.ExpectedKeysMode = "disable"

    _settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.completions_config"]

    def __init__(
        self,
        context: str = "",
        chat_history: tp.Optional[tp.ChatHistory] = None,
        stream: tp.Optional[bool] = None,
        max_tokens: tp.Optional[int] = None,
        tokenizer: tp.TokenizerLike = None,
        tokenizer_kwargs: tp.KwargsLike = None,
        system_prompt: tp.Optional[str] = None,
        system_as_user: tp.Optional[bool] = None,
        context_prompt: tp.Optional[str] = None,
        formatter: tp.ContentFormatterLike = None,
        formatter_kwargs: tp.KwargsLike = None,
        minimal_format: tp.Optional[bool] = None,
        silence_warnings: tp.Optional[bool] = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        Configured.__init__(
            self,
            context=context,
            chat_history=chat_history,
            stream=stream,
            max_tokens=max_tokens,
            tokenizer=tokenizer,
            tokenizer_kwargs=tokenizer_kwargs,
            system_prompt=system_prompt,
            system_as_user=system_as_user,
            context_prompt=context_prompt,
            formatter=formatter,
            formatter_kwargs=formatter_kwargs,
            minimal_format=minimal_format,
            silence_warnings=silence_warnings,
            template_context=template_context,
            **kwargs,
        )

        if chat_history is None:
            chat_history = []
        stream = self.resolve_setting(stream, "stream")
        # Remember whether the limit came from the caller before it is filled from the settings
        max_tokens_set = max_tokens is not None
        max_tokens = self.resolve_setting(max_tokens, "max_tokens")
        tokenizer = self.resolve_setting(tokenizer, "tokenizer", default=None)
        tokenizer_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True)
        system_prompt = self.resolve_setting(system_prompt, "system_prompt")
        system_as_user = self.resolve_setting(system_as_user, "system_as_user")
        context_prompt = self.resolve_setting(context_prompt, "context_prompt")
        formatter = self.resolve_setting(formatter, "formatter", default=None)
        formatter_kwargs = self.resolve_setting(formatter_kwargs, "formatter_kwargs", default=None, merge=True)
        minimal_format = self.resolve_setting(minimal_format, "minimal_format", default=None)
        silence_warnings = self.resolve_setting(silence_warnings, "silence_warnings")
        template_context = self.resolve_setting(template_context, "template_context", merge=True)

        tokenizer = resolve_tokenizer(tokenizer)
        formatter = resolve_formatter(formatter)

        self._context = context
        self._chat_history = chat_history
        self._stream = stream
        self._max_tokens_set = max_tokens_set
        self._max_tokens = max_tokens
        self._tokenizer = tokenizer
        self._tokenizer_kwargs = tokenizer_kwargs
        self._system_prompt = system_prompt
        self._system_as_user = system_as_user
        self._context_prompt = context_prompt
        self._formatter = formatter
        self._formatter_kwargs = formatter_kwargs
        self._minimal_format = minimal_format
        self._silence_warnings = silence_warnings
        self._template_context = template_context

    @property
    def context(self) -> str:
        """Context.

        Becomes a user message."""
        return self._context

    @property
    def chat_history(self) -> tp.ChatHistory:
        """Chat history.

        Must be list of dictionaries with proper roles.

        After generating a response, the output will be appended to this sequence as an assistant message."""
        return self._chat_history

    @property
    def stream(self) -> bool:
        """Whether to stream the response.

        When streaming, appends chunks one by one and displays the intermediate result.
        Otherwise, displays the entire message."""
        return self._stream

    @property
    def max_tokens_set(self) -> bool:
        """Whether the user provided `max_tokens`."""
        return self._max_tokens_set

    @property
    def max_tokens(self) -> tp.Optional[int]:
        """Maximum number of tokens in messages."""
        return self._max_tokens

    @property
    def tokenizer(self) -> tp.MaybeType[Tokenizer]:
        """A subclass or an instance of `Tokenizer`.

        Resolved with `resolve_tokenizer`."""
        return self._tokenizer

    @property
    def tokenizer_kwargs(self) -> tp.Kwargs:
        """Keyword arguments passed to `Completions.tokenizer`.

        Used either to initialize a class or replace an instance of `Tokenizer`."""
        return self._tokenizer_kwargs

    @property
    def system_prompt(self) -> str:
        """System prompt.

        Precedes the context prompt."""
        return self._system_prompt

    @property
    def system_as_user(self) -> bool:
        """Whether to use the user role for the system message.

        Mainly for experimental models where the system role is not available."""
        return self._system_as_user

    @property
    def context_prompt(self) -> str:
        """Context prompt.

        A prompt template requiring the variable "context". The prompt can be either a custom template,
        or string or function that will become one. Once the prompt is evaluated, it becomes a user message."""
        return self._context_prompt

    @property
    def formatter(self) -> tp.MaybeType[ContentFormatter]:
        """A subclass or an instance of `vectorbtpro.utils.knowledge.formatting.ContentFormatter`.

        Resolved with `vectorbtpro.utils.knowledge.formatting.resolve_formatter`."""
        return self._formatter

    @property
    def formatter_kwargs(self) -> tp.Kwargs:
        """Keyword arguments passed to `Completions.formatter`.

        Used either to initialize a class or replace an instance of
        `vectorbtpro.utils.knowledge.formatting.ContentFormatter`."""
        return self._formatter_kwargs

    @property
    def minimal_format(self) -> bool:
        """Whether input is minimally-formatted."""
        return self._minimal_format

    @property
    def silence_warnings(self) -> bool:
        """Whether to silence warnings."""
        return self._silence_warnings

    @property
    def template_context(self) -> tp.Kwargs:
        """Context used to substitute templates."""
        return self._template_context

    @property
    def model(self) -> tp.Optional[str]:
        """Model."""
        return None

    def get_chat_response(self, messages: tp.ChatMessages, **kwargs) -> tp.Any:
        """Get chat response to messages."""
        raise NotImplementedError

    def get_message_content(self, response: tp.Any) -> tp.Optional[str]:
        """Get content from a chat response."""
        raise NotImplementedError

    def get_stream_response(self, messages: tp.ChatMessages, **kwargs) -> tp.Any:
        """Get streaming response to messages."""
        raise NotImplementedError

    def get_delta_content(self, response: tp.Any) -> tp.Optional[str]:
        """Get content from a streaming response chunk."""
        raise NotImplementedError

    def prepare_messages(self, message: str) -> tp.ChatMessages:
        """Prepare messages for a completion.

        The result consists of the system prompt, the evaluated context prompt (only if
        `Completions.context` is not empty), the chat history, and `message` as the last user message.

        If `Completions.max_tokens` is set, the context is truncated such that all messages
        together fit into the limit. A warning about the truncation is emitted only if the limit
        was not provided by the user and warnings are not silenced.

        Raises a `TypeError` if the context prompt is not a string, function, or template."""
        context = self.context
        chat_history = self.chat_history
        max_tokens_set = self.max_tokens_set
        max_tokens = self.max_tokens
        tokenizer = self.tokenizer
        tokenizer_kwargs = self.tokenizer_kwargs
        system_prompt = self.system_prompt
        system_as_user = self.system_as_user
        context_prompt = self.context_prompt
        template_context = self.template_context
        silence_warnings = self.silence_warnings

        if isinstance(tokenizer, type):
            # Copy to avoid mutating the stored keyword arguments
            tokenizer_kwargs = dict(tokenizer_kwargs or {})
            tokenizer_kwargs["template_context"] = merge_dicts(
                template_context, tokenizer_kwargs.get("template_context", None)
            )
            if issubclass(tokenizer, TikTokenizer) and "model" not in tokenizer_kwargs:
                # Let tiktoken pick the encoding that belongs to the completion model
                tokenizer_kwargs["model"] = self.model
            tokenizer = tokenizer(**tokenizer_kwargs)
        elif tokenizer_kwargs:
            tokenizer = tokenizer.replace(**tokenizer_kwargs)

        if context:
            if isinstance(context_prompt, str):
                context_prompt = SafeSub(context_prompt)
            elif checks.is_function(context_prompt):
                context_prompt = RepFunc(context_prompt)
            elif not isinstance(context_prompt, CustomTemplate):
                raise TypeError("Context prompt must be a string, function, or template")
            if max_tokens is not None:
                # Count the tokens of everything but the context to get the budget left for it
                empty_context_prompt = context_prompt.substitute(
                    flat_merge_dicts(dict(context=""), template_context),
                    eval_id="context_prompt",
                )
                empty_messages = [
                    dict(role="user" if system_as_user else "system", content=system_prompt),
                    dict(role="user", content=empty_context_prompt),
                    *chat_history,
                    dict(role="user", content=message),
                ]
                num_tokens = tokenizer.count_tokens_in_messages(empty_messages)
                max_context_tokens = max(0, max_tokens - num_tokens)
                encoded_context = tokenizer.encode(context)
                if len(encoded_context) > max_context_tokens:
                    context = tokenizer.decode(encoded_context[:max_context_tokens])
                    if not max_tokens_set and not silence_warnings:
                        warn(
                            f"Context is too long ({len(encoded_context)}). "
                            f"Truncating to {max_context_tokens} tokens."
                        )
            template_context = flat_merge_dicts(dict(context=context), template_context)
            context_prompt = context_prompt.substitute(template_context, eval_id="context_prompt")
            return [
                dict(role="user" if system_as_user else "system", content=system_prompt),
                dict(role="user", content=context_prompt),
                *chat_history,
                dict(role="user", content=message),
            ]
        else:
            return [
                dict(role="user" if system_as_user else "system", content=system_prompt),
                *chat_history,
                dict(role="user", content=message),
            ]

    def get_completion(
        self,
        message: str,
        return_response: bool = False,
    ) -> tp.ChatOutput:
        """Get completion for a message.

        Prepares the messages, requests a (streaming) response, feeds the content into the
        formatter, and appends both the message and the answer to `Completions.chat_history`.

        Returns the path of the file written by the formatter (None if the formatter is not
        an `HTMLFileFormatter` with an open file handle). If `return_response` is True,
        returns a tuple of the path and the raw response."""
        chat_history = self.chat_history
        stream = self.stream
        formatter = self.formatter
        formatter_kwargs = self.formatter_kwargs
        template_context = self.template_context

        messages = self.prepare_messages(message)
        if stream:
            response = self.get_stream_response(messages)
        else:
            response = self.get_chat_response(messages)

        if isinstance(formatter, type):
            # Copy to avoid mutating the stored keyword arguments
            formatter_kwargs = dict(formatter_kwargs or {})
            if "minimal_format" not in formatter_kwargs:
                formatter_kwargs["minimal_format"] = self.minimal_format
            formatter_kwargs["template_context"] = merge_dicts(
                template_context, formatter_kwargs.get("template_context", None)
            )
            if issubclass(formatter, HTMLFileFormatter):
                if "page_title" not in formatter_kwargs:
                    formatter_kwargs["page_title"] = message
                # Derive the output directory only if the user has not chosen one
                user_dir_given = "cache_dir" in formatter_kwargs or formatter_kwargs.get("dir_path", None) is not None
                if not user_dir_given:
                    chat_dir = self.get_setting("chat_dir", default=None)
                    if isinstance(chat_dir, CustomTemplate):
                        # The chat directory template may refer to the cache and release directories
                        cache_dir = self.get_setting("cache_dir", default=None)
                        if cache_dir is not None:
                            if isinstance(cache_dir, CustomTemplate):
                                cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
                            template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
                        release_dir = self.get_setting("release_dir", default=None)
                        if release_dir is not None:
                            if isinstance(release_dir, CustomTemplate):
                                release_dir = release_dir.substitute(template_context, eval_id="release_dir")
                            template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
                        chat_dir = chat_dir.substitute(template_context, eval_id="chat_dir")
                    if chat_dir is not None:
                        # Without a configured chat directory, the formatter keeps its own default
                        formatter_kwargs["dir_path"] = Path(chat_dir) / "html"
            formatter = formatter(**formatter_kwargs)
        elif formatter_kwargs:
            formatter = formatter.replace(**formatter_kwargs)

        if stream:
            with formatter:
                for response_chunk in response:
                    new_content = self.get_delta_content(response_chunk)
                    if new_content is not None:
                        formatter.append(new_content)
                content = formatter.content
        else:
            content = self.get_message_content(response)
            if content is None:
                content = ""
            formatter.append_once(content)

        chat_history.append(dict(role="user", content=message))
        chat_history.append(dict(role="assistant", content=content))
        if isinstance(formatter, HTMLFileFormatter) and formatter.file_handle is not None:
            file_path = Path(formatter.file_handle.name)
        else:
            file_path = None
        if return_response:
            return file_path, response
        return file_path
def __init__(
    self,
    context: str = "",
    chat_history: tp.Optional[tp.ChatHistory] = None,
    stream: tp.Optional[bool] = None,
    max_tokens: tp.Optional[int] = None,
    tokenizer: tp.TokenizerLike = None,
    tokenizer_kwargs: tp.KwargsLike = None,
    system_prompt: tp.Optional[str] = None,
    system_as_user: tp.Optional[bool] = None,
    context_prompt: tp.Optional[str] = None,
    formatter: tp.ContentFormatterLike = None,
    formatter_kwargs: tp.KwargsLike = None,
    silence_warnings: tp.Optional[bool] = None,
    template_context: tp.KwargsLike = None,
    model: tp.Optional[str] = None,
    minimal_format: tp.Optional[bool] = None,
    **kwargs,
) -> None:
    # `minimal_format` is a named argument (appended last to keep the positional order intact)
    # so that it is recognized as a constructor argument below and never forwarded to OpenAI
    Completions.__init__(
        self,
        context=context,
        chat_history=chat_history,
        stream=stream,
        max_tokens=max_tokens,
        tokenizer=tokenizer,
        tokenizer_kwargs=tokenizer_kwargs,
        system_prompt=system_prompt,
        system_as_user=system_as_user,
        context_prompt=context_prompt,
        formatter=formatter,
        formatter_kwargs=formatter_kwargs,
        minimal_format=minimal_format,
        silence_warnings=silence_warnings,
        template_context=template_context,
        model=model,
        **kwargs,
    )

    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("openai")
    from openai import OpenAI

    # Provider-specific settings, overridden by explicitly passed keyword arguments
    openai_config = merge_dicts(self.get_settings(inherit=False), kwargs)
    def_model = openai_config.pop("model", None)
    if model is None:
        model = def_model
    if model is None:
        raise ValueError("Must provide a model")

    # Drop everything that is an argument of this constructor rather than of OpenAI
    init_kwargs = get_func_kwargs(type(self).__init__)
    for k in list(openai_config.keys()):
        if k in init_kwargs:
            openai_config.pop(k)

    # Keys accepted by the client constructor go to the client, the rest to the completion call
    client_arg_names = set(get_func_arg_names(OpenAI.__init__))
    client_kwargs = {}
    completion_kwargs = {}
    for k, v in openai_config.items():
        if k in client_arg_names:
            client_kwargs[k] = v
        else:
            completion_kwargs[k] = v
    client = OpenAI(**client_kwargs)

    self._model = model
    self._client = client
    self._completion_kwargs = completion_kwargs
def get_chat_response(self, messages: tp.ChatMessages) -> ChatCompletionT:
    # Non-streaming request; extra keyword arguments come from the provider config
    create = self.client.chat.completions.create
    return create(messages=messages, model=self.model, stream=False, **self.completion_kwargs)
def __init__(
    self,
    context: str = "",
    chat_history: tp.Optional[tp.ChatHistory] = None,
    stream: tp.Optional[bool] = None,
    max_tokens: tp.Optional[int] = None,
    tokenizer: tp.TokenizerLike = None,
    tokenizer_kwargs: tp.KwargsLike = None,
    system_prompt: tp.Optional[str] = None,
    system_as_user: tp.Optional[bool] = None,
    context_prompt: tp.Optional[str] = None,
    formatter: tp.ContentFormatterLike = None,
    formatter_kwargs: tp.KwargsLike = None,
    silence_warnings: tp.Optional[bool] = None,
    template_context: tp.KwargsLike = None,
    model: tp.Optional[str] = None,
    minimal_format: tp.Optional[bool] = None,
    **kwargs,
) -> None:
    # `minimal_format` is a named argument (appended last to keep the positional order intact)
    # so that it is recognized as a constructor argument below and never forwarded to LiteLLM
    Completions.__init__(
        self,
        context=context,
        chat_history=chat_history,
        stream=stream,
        max_tokens=max_tokens,
        tokenizer=tokenizer,
        tokenizer_kwargs=tokenizer_kwargs,
        system_prompt=system_prompt,
        system_as_user=system_as_user,
        context_prompt=context_prompt,
        formatter=formatter,
        formatter_kwargs=formatter_kwargs,
        minimal_format=minimal_format,
        silence_warnings=silence_warnings,
        template_context=template_context,
        model=model,
        **kwargs,
    )

    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("litellm")

    # Provider-specific settings, overridden by explicitly passed keyword arguments
    completion_kwargs = merge_dicts(self.get_settings(inherit=False), kwargs)
    def_model = completion_kwargs.pop("model", None)
    if model is None:
        model = def_model
    if model is None:
        raise ValueError("Must provide a model")

    # Constructor-level arguments (such as `stream`, which is passed explicitly on each call)
    # must not reach `litellm.completion`; the OpenAI counterpart strips them the same way
    init_kwargs = get_func_kwargs(type(self).__init__)
    for k in list(completion_kwargs.keys()):
        if k in init_kwargs:
            completion_kwargs.pop(k)

    self._model = model
    self._completion_kwargs = completion_kwargs
def __init__(
    self,
    context: str = "",
    chat_history: tp.Optional[tp.ChatHistory] = None,
    stream: tp.Optional[bool] = None,
    max_tokens: tp.Optional[int] = None,
    tokenizer: tp.TokenizerLike = None,
    tokenizer_kwargs: tp.KwargsLike = None,
    system_prompt: tp.Optional[str] = None,
    system_as_user: tp.Optional[bool] = None,
    context_prompt: tp.Optional[str] = None,
    formatter: tp.ContentFormatterLike = None,
    formatter_kwargs: tp.KwargsLike = None,
    silence_warnings: tp.Optional[bool] = None,
    template_context: tp.KwargsLike = None,
    llm: tp.Union[None, str, tp.MaybeType[LLMT]] = None,
    minimal_format: tp.Optional[bool] = None,
    **kwargs,
) -> None:
    # `minimal_format` is a named argument (appended last to keep the positional order intact)
    # so that it neither reaches the LLM constructor nor counts as config for an LLM instance
    Completions.__init__(
        self,
        context=context,
        chat_history=chat_history,
        stream=stream,
        max_tokens=max_tokens,
        tokenizer=tokenizer,
        tokenizer_kwargs=tokenizer_kwargs,
        system_prompt=system_prompt,
        system_as_user=system_as_user,
        context_prompt=context_prompt,
        formatter=formatter,
        formatter_kwargs=formatter_kwargs,
        minimal_format=minimal_format,
        silence_warnings=silence_warnings,
        template_context=template_context,
        llm=llm,
        **kwargs,
    )

    from vectorbtpro.utils.module_ import assert_can_import

    assert_can_import("llama_index")
    from llama_index.core.llms import LLM

    # Provider-specific settings, overridden by explicitly passed keyword arguments
    llama_index_config = merge_dicts(self.get_settings(inherit=False), kwargs)
    def_llm = llama_index_config.pop("llm", None)
    if llm is None:
        llm = def_llm
    if llm is None:
        raise ValueError("Must provide an LLM name or path")

    # Drop everything that is an argument of this constructor rather than of the LLM
    init_kwargs = get_func_kwargs(type(self).__init__)
    for k in list(llama_index_config.keys()):
        if k in init_kwargs:
            llama_index_config.pop(k)

    if isinstance(llm, str):
        import llama_index.llms
        from vectorbtpro.utils.module_ import search_package

        def _match_func(k, v):
            # Only subclasses of LLM are candidates
            if isinstance(v, type) and issubclass(v, LLM):
                if "." in llm:
                    # A dotted name is matched as a (case-sensitive) path suffix
                    if k.endswith(llm):
                        return True
                else:
                    # Otherwise compare the class name, with and without the "LLM" part
                    if k.split(".")[-1].lower() == llm.lower():
                        return True
                    if k.split(".")[-1].replace("LLM", "").lower() == llm.lower().replace("_", ""):
                        return True
            return False

        found_llm = search_package(
            llama_index.llms,
            _match_func,
            path_attrs=True,
            return_first=True,
        )
        if found_llm is None:
            raise ValueError(f"LLM '{llm}' not found")
        llm = found_llm

    if isinstance(llm, type):
        checks.assert_subclass_of(llm, LLM, arg_name="llm")
        llm_name = llm.__name__.replace("LLM", "").lower()
        module_name = llm.__module__
    else:
        checks.assert_instance_of(llm, LLM, arg_name="llm")
        llm_name = type(llm).__name__.replace("LLM", "").lower()
        module_name = type(llm).__module__

    # Per-LLM overrides are looked up by short class name first, then by module name
    llm_configs = llama_index_config.pop("llm_configs", {})
    if llm_name in llm_configs:
        llama_index_config = merge_dicts(llama_index_config, llm_configs[llm_name])
    elif module_name in llm_configs:
        llama_index_config = merge_dicts(llama_index_config, llm_configs[module_name])

    if isinstance(llm, type):
        llm = llm(**llama_index_config)
    elif len(kwargs) > 0:
        raise ValueError("Cannot apply config to already initialized LLM")

    # The model comes from the config or, failing that, from the constructor's default
    model = llama_index_config.get("model", None)
    if model is None:
        func_kwargs = get_func_kwargs(type(llm).__init__)
        model = func_kwargs.get("model", None)

    self._model = model
    self._llm = llm
def resolve_completions(completions: tp.CompletionsLike = None) -> tp.MaybeType[Completions]:
    """Resolve a subclass or an instance of `Completions`.

    The following values are supported:

    * "openai" (`OpenAICompletions`)
    * "litellm" (`LiteLLMCompletions`)
    * "llama_index" (`LlamaIndexCompletions`)
    * "auto": Any installed from above, in the same order
    * A subclass or an instance of `Completions`

    If None, the value is taken from `knowledge.chat.completions` in the settings.
    """
    if completions is None:
        from vectorbtpro._settings import settings

        completions = settings["knowledge"]["chat"]["completions"]

    if isinstance(completions, str):
        if completions.lower() == "auto":
            from vectorbtpro.utils.module_ import check_installed

            # Take the first installed package, probing in the documented order
            for package_name in ("openai", "litellm", "llama_index"):
                if check_installed(package_name):
                    completions = package_name
                    break
            else:
                raise ValueError("No packages for completions installed")

        # Look for a class in this module whose short name matches (case-insensitive)
        wanted = completions.lower()
        for cls_name, candidate in inspect.getmembers(sys.modules[__name__], inspect.isclass):
            if not cls_name.endswith("Completions"):
                continue
            short_name = getattr(candidate, "_short_name", None)
            if short_name is not None and short_name.lower() == wanted:
                completions = candidate
                break
        else:
            raise ValueError(f"Invalid completions: '{completions}'")

    if isinstance(completions, type):
        checks.assert_subclass_of(completions, Completions, arg_name="completions")
    else:
        checks.assert_instance_of(completions, Completions, arg_name="completions")
    return completions
def split_text(self, text: str) -> tp.TSTextChunks:
    """Split text and yield one rendered text chunk per range returned by `TextSplitter.split`.

    Each chunk is rendered by substituting `TextSplitter.chunk_template` with the context
    `chunk_idx`, `chunk_start`, `chunk_end`, `chunk_text`, and `text`, merged with
    `TextSplitter.template_context`.

    Raises:
        TypeError: If the chunk template is not a string, function, or template.
            As before, this is only raised once the first chunk is produced.
    """
    resolved_template = None
    for chunk_idx, (chunk_start, chunk_end) in enumerate(self.split(text)):
        if resolved_template is None:
            # The template does not depend on the chunk, so wrap it only once
            # instead of building a new template object for every chunk
            chunk_template = self.chunk_template
            if isinstance(chunk_template, str):
                resolved_template = SafeSub(chunk_template)
            elif checks.is_function(chunk_template):
                resolved_template = RepFunc(chunk_template)
            elif isinstance(chunk_template, CustomTemplate):
                resolved_template = chunk_template
            else:
                raise TypeError("Chunk template must be a string, function, or template")
        template_context = flat_merge_dicts(
            dict(
                chunk_idx=chunk_idx,
                chunk_start=chunk_start,
                chunk_end=chunk_end,
                chunk_text=text[chunk_start:chunk_end],
                text=text,
            ),
            self.template_context,
        )
        yield resolved_template.substitute(template_context, eval_id="chunk_template")
def __init__(
    self,
    chunk_size: tp.Optional[int] = None,
    chunk_overlap: tp.Union[None, int, float] = None,
    tokenizer: tp.TokenizerLike = None,
    tokenizer_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> None:
    """Resolve the chunk size, chunk overlap, and tokenizer from arguments and settings.

    A float overlap whose absolute value is between 0 and 1 is multiplied by the chunk size;
    any other float must be integer-valued. The overlap is finally converted to an integer.

    Raises:
        TypeError: If a float overlap is outside this range and not integer-valued.
        ValueError: If the overlap is not less than the chunk size.
    """
    TextSplitter.__init__(
        self,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        tokenizer=tokenizer,
        tokenizer_kwargs=tokenizer_kwargs,
        **kwargs,
    )

    size = self.resolve_setting(chunk_size, "chunk_size")
    overlap = self.resolve_setting(chunk_overlap, "chunk_overlap")
    tokenizer_spec = self.resolve_setting(tokenizer, "tokenizer", default=None)
    extra_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True)

    resolved_tokenizer = resolve_tokenizer(tokenizer_spec)
    if isinstance(resolved_tokenizer, type):
        # A class gets instantiated; the splitter's template context is merged
        # with the one given for the tokenizer
        build_kwargs = dict(extra_kwargs)
        build_kwargs["template_context"] = merge_dicts(
            self.template_context, build_kwargs.get("template_context", None)
        )
        resolved_tokenizer = resolved_tokenizer(**build_kwargs)
    elif extra_kwargs:
        # An instance gets copied with the new arguments
        resolved_tokenizer = resolved_tokenizer.replace(**extra_kwargs)

    if checks.is_float(overlap):
        if 0 <= abs(overlap) <= 1:
            overlap = overlap * size
        elif not overlap.is_integer():
            raise TypeError("Floating number for chunk_overlap must be between 0 and 1")
        overlap = int(overlap)
    if overlap >= size:
        raise ValueError("Chunk overlap must be less than the chunk size")

    self._chunk_size = size
    self._chunk_overlap = overlap
    self._tokenizer = resolved_tokenizer
def split(self, text: str) -> tp.TSRangeChunks:
    """Yield the start and end character position of each chunk of tokens.

    Windows of `TokenSplitter.chunk_size` tokens are taken over the ranges returned by
    `TokenSplitter.split_into_tokens`; consecutive windows are shifted by the chunk size
    minus `TokenSplitter.chunk_overlap`. Nothing is yielded if there are no tokens.
    """
    spans = list(self.split_into_tokens(text))
    n_spans = len(spans)

    window_start = 0
    while window_start < n_spans:
        window = spans[window_start:window_start + self.chunk_size]
        first_span, last_span = window[0], window[-1]
        yield first_span[0], last_span[1]

        # Stop once the current window has reached the last token
        if window_start + self.chunk_size >= n_spans:
            return
        window_start += self.chunk_size - self.chunk_overlap
def __init__(
    self,
    separators: tp.MaybeList[tp.MaybeList[tp.Optional[str]]] = None,
    min_chunk_size: tp.Union[None, int, float] = None,
    fixed_overlap: tp.Optional[bool] = None,
    **kwargs,
) -> None:
    """Resolve separators, minimum chunk size, and the overlap mode from arguments and settings.

    Separators are normalized into a list of layers, each layer being a list itself.
    A float minimum chunk size whose absolute value is between 0 and 1 is multiplied by
    `TokenSplitter.chunk_size`; any other float must be integer-valued.

    Raises:
        TypeError: If a float minimum chunk size is outside this range and not integer-valued.
    """
    TokenSplitter.__init__(
        self,
        separators=separators,
        min_chunk_size=min_chunk_size,
        fixed_overlap=fixed_overlap,
        **kwargs,
    )

    raw_separators = self.resolve_setting(separators, "separators")
    min_size = self.resolve_setting(min_chunk_size, "min_chunk_size")
    use_fixed_overlap = self.resolve_setting(fixed_overlap, "fixed_overlap")

    # Wrap a non-list into a single layer, then wrap each non-list layer
    # into a list; lists are copied in both cases
    outer = list(raw_separators) if isinstance(raw_separators, list) else [raw_separators]
    layers = [list(item) if isinstance(item, list) else [item] for item in outer]

    if checks.is_float(min_size):
        if 0 <= abs(min_size) <= 1:
            min_size = min_size * self.chunk_size
        elif not min_size.is_integer():
            raise TypeError("Floating number for min_chunk_size must be between 0 and 1")
        min_size = int(min_size)

    self._separators = layers
    self._min_chunk_size = min_size
    self._fixed_overlap = use_fixed_overlap
def split_into_segments(self, text: str, separator: tp.Optional[str] = None) -> tp.TSSegmentChunks:
    """Split text into segments and yield `(start, end, is_separator)` for each of them.

    Args:
        text: Text to split.
        separator: Regular expression pattern passed to `re.finditer`.

            If None, text is split into tokens using `TokenSplitter.split_into_tokens`.
            If an empty string, text is split into single characters.

    Each match of the separator is yielded as its own segment flagged with True,
    while the text between matches (and after the last match) is flagged with False.
    """
    if separator is None:
        for start, end in self.split_into_tokens(text):
            yield start, end, False
        return
    if not separator:
        for i in range(len(text)):
            yield i, i + 1, False
        return

    last_end = 0
    for match in re.finditer(separator, text):
        start, end = match.span()
        if start > last_end:
            # Text between the previous separator and this one
            yield last_end, start, False
        yield start, end, True
        last_end = end
    if last_end < len(text):
        # Trailing text after the last separator
        yield last_end, len(text), False
self.tokenizer.encode(segment_text) new_curr_stable_token_count = 0 new_curr_stable_char_count = 0 elif not curr_stable_token_count: chunk_text = text[chunk_start:segment_end] new_curr_tokens = self.tokenizer.encode(chunk_text) new_curr_stable_token_count = 0 new_curr_stable_char_count = 0 min_token_count = min(len(curr_tokens), len(new_curr_tokens)) for i in range(min_token_count): if curr_tokens[i] == new_curr_tokens[i]: new_curr_stable_token_count += 1 new_curr_stable_char_count += len(self.tokenizer.decode_single(curr_tokens[i])) else: break else: stable_tokens = curr_tokens[:curr_stable_token_count] unstable_start = chunk_start + curr_stable_char_count partial_text = text[unstable_start:segment_end] partial_tokens = self.tokenizer.encode(partial_text) new_curr_tokens = stable_tokens + partial_tokens new_curr_stable_token_count = curr_stable_token_count new_curr_stable_char_count = curr_stable_char_count min_token_count = min(len(curr_tokens), len(new_curr_tokens)) for i in range(curr_stable_token_count, min_token_count): if curr_tokens[i] == new_curr_tokens[i]: new_curr_stable_token_count += 1 new_curr_stable_char_count += len(self.tokenizer.decode_single(curr_tokens[i])) else: break if len(new_curr_tokens) > self.chunk_size: if segment_is_separator: if ( sep_curr_segments and len(sep_curr_tokens) >= self.min_chunk_size and not (self.chunk_overlap and len(sep_curr_tokens) <= self.chunk_overlap) ): curr_segments = list(sep_curr_segments) curr_tokens = list(sep_curr_tokens) curr_stable_token_count = sep_curr_stable_token_count curr_stable_char_count = sep_curr_stable_char_count segment_start = curr_segments[-1][0] segment_end = curr_segments[-1][1] curr_text = text[segment_start:segment_end] curr_start = segment_start finished = False break else: curr_segments.append((segment_start, segment_end, segment_is_separator)) token_offset_map[segment_start] = len(curr_tokens) curr_tokens = new_curr_tokens curr_stable_token_count = new_curr_stable_token_count 
curr_stable_char_count = new_curr_stable_char_count if segment_is_separator: sep_curr_segments = list(curr_segments) sep_curr_tokens = list(curr_tokens) sep_curr_stable_token_count = curr_stable_token_count sep_curr_stable_char_count = curr_stable_char_count finished = True if finished: break if ( curr_segments and len(curr_tokens) >= self.min_chunk_size and not (self.chunk_overlap and len(curr_tokens) <= self.chunk_overlap) ): chunk_start = curr_segments[0][0] chunk_end = curr_segments[-1][1] yield chunk_start, chunk_end if chunk_end == len(text): break if self.chunk_overlap: fixed_overlap = True if not self.fixed_overlap: for segment in curr_segments: if not segment[2]: token_offset = token_offset_map[segment[0]] if token_offset > curr_stable_token_count: break if len(curr_tokens) - token_offset <= self.chunk_overlap: chunk_tokens = curr_tokens[token_offset:] new_chunk_start = segment[0] chunk_offset = new_chunk_start - chunk_start chunk_start = new_chunk_start chunk_continue = chunk_end fixed_overlap = False break if fixed_overlap: chunk_tokens = curr_tokens[-self.chunk_overlap:] token_offset = len(curr_tokens) - len(chunk_tokens) new_chunk_start = chunk_end - len(self.tokenizer.decode(chunk_tokens)) chunk_offset = new_chunk_start - chunk_start chunk_start = new_chunk_start chunk_continue = chunk_end stable_token_count = max(0, curr_stable_token_count - token_offset) stable_char_count = max(0, curr_stable_char_count - chunk_offset) overlap_segments = [(chunk_start, chunk_end, False)] token_offset_map[chunk_start] = 0 else: chunk_tokens = [] chunk_start = chunk_end chunk_continue = 0 stable_token_count = 0 stable_char_count = 0 overlap_segments = [] token_offset_map = {} if chunk_continue: remaining_text = text[chunk_continue:] else: remaining_text = text[chunk_start:] layer = 0 else: layer += 1 if layer == len(self.separators): if curr_segments and curr_segments[-1][1] == len(text): chunk_start = curr_segments[0][0] chunk_end = curr_segments[-1][1] yield 
class LlamaIndexSplitter(TextSplitter):
    """Splitter class based on a node parser from LlamaIndex.

    Node parser can be provided via `node_parser`, which can be either the name of the class
    (case doesn't matter), the path or its suffix to the class (case matters), or a subclass
    or an instance of `llama_index.core.node_parser.NodeParser`. Keyword arguments are passed
    to the resolved node parser.

    For defaults, see `chat.text_splitter_configs.llama_index` in `vectorbtpro._settings.knowledge`."""

    _short_name = "llama_index"

    _settings_path: tp.SettingsPath = "knowledge.chat.text_splitter_configs.llama_index"

    def __init__(
        self,
        node_parser: tp.Union[None, str, NodeParserT] = None,
        template_context: tp.KwargsLike = None,
        **kwargs,
    ) -> None:
        """Initialize the base splitter and resolve the node parser.

        Raises:
            ValueError: If a string cannot be matched to a node parser class, or if keyword
                arguments are passed together with an already initialized node parser.
        """
        # Forward `node_parser` as well so that it becomes part of the stored config,
        # the same way `LlamaIndexCompletions` forwards `llm`
        TextSplitter.__init__(
            self,
            node_parser=node_parser,
            template_context=template_context,
            **kwargs,
        )

        from vectorbtpro.utils.module_ import assert_can_import

        assert_can_import("llama_index")
        from llama_index.core.node_parser import NodeParser

        # Class-specific settings (no inheritance) overridden by the extra keyword arguments
        llama_index_config = merge_dicts(self.get_settings(inherit=False), kwargs)
        def_node_parser = llama_index_config.pop("node_parser", None)
        if node_parser is None:
            node_parser = def_node_parser
        # Drop every key that is a parameter of this constructor so that only
        # parser-specific keys remain in the config
        init_kwargs = get_func_kwargs(type(self).__init__)
        for k in list(llama_index_config.keys()):
            if k in init_kwargs:
                llama_index_config.pop(k)

        if isinstance(node_parser, str):
            import llama_index.core.node_parser
            from vectorbtpro.utils.module_ import search_package

            def _match_func(k, v):
                if isinstance(v, type) and issubclass(v, NodeParser):
                    if "." in node_parser:
                        # Dotted string: case-sensitive suffix match against the path
                        if k.endswith(node_parser):
                            return True
                    else:
                        # Plain name: case-insensitive match against the class name, with or
                        # without "Splitter"/"NodeParser" and ignoring underscores in the query
                        attr_name = k.split(".")[-1]
                        if attr_name.lower() == node_parser.lower():
                            return True
                        stripped_name = attr_name.replace("Splitter", "").replace("NodeParser", "")
                        if stripped_name.lower() == node_parser.lower().replace("_", ""):
                            return True
                return False

            found_node_parser = search_package(
                llama_index.core.node_parser,
                _match_func,
                path_attrs=True,
                return_first=True,
            )
            if found_node_parser is None:
                raise ValueError(f"Node parser '{node_parser}' not found")
            node_parser = found_node_parser
        if isinstance(node_parser, type):
            checks.assert_subclass_of(node_parser, NodeParser, arg_name="node_parser")
            node_parser_name = node_parser.__name__.replace("Splitter", "").replace("NodeParser", "").lower()
            module_name = node_parser.__module__
        else:
            checks.assert_instance_of(node_parser, NodeParser, arg_name="node_parser")
            node_parser_name = type(node_parser).__name__.replace("Splitter", "").replace("NodeParser", "").lower()
            module_name = type(node_parser).__module__
        # Per-parser overrides are looked up by normalized class name first, then by module name
        node_parser_configs = llama_index_config.pop("node_parser_configs", {})
        if node_parser_name in node_parser_configs:
            llama_index_config = merge_dicts(llama_index_config, node_parser_configs[node_parser_name])
        elif module_name in node_parser_configs:
            llama_index_config = merge_dicts(llama_index_config, node_parser_configs[module_name])
        if isinstance(node_parser, type):
            node_parser = node_parser(**llama_index_config)
        elif kwargs:
            raise ValueError("Cannot apply config to already initialized node parser")
        model_name = llama_index_config.get("model_name", None)
        if model_name is None:
            func_kwargs = get_func_kwargs(type(node_parser).__init__)
            model_name = func_kwargs.get("model_name", None)

        self._model = model_name
        self._node_parser = node_parser

    @property
    def node_parser(self) -> NodeParserT:
        """An instance of `llama_index.core.node_parser.interface.NodeParser`."""
        return self._node_parser

    def split(self, text: str) -> tp.TSRangeChunks:
        """Yield the start and end character position of each chunk from `LlamaIndexSplitter.split_text`.

        Chunks are located in order: each search starts right after the start of the previously
        located chunk, so a chunk whose text occurs more than once is not always mapped to its
        first occurrence. If a chunk isn't found from there, the whole text is searched.
        A chunk that cannot be found at all yields `(-1, -1)`."""
        search_from = 0
        for text_chunk in self.split_text(text):
            start = text.find(text_chunk, search_from)
            if start == -1 and search_from > 0:
                # Fall back to the previous behavior of searching the whole text
                start = text.find(text_chunk)
            if start == -1:
                yield -1, -1
                continue
            yield start, start + len(text_chunk)
            # Chunks may overlap, so advance past the start only, not past the end
            search_from = start + 1

    def split_text(self, text: str) -> tp.TSTextChunks:
        """Wrap text into a document, parse it with the node parser, and yield the text of each node."""
        from llama_index.core.schema import Document

        nodes = self.node_parser.get_nodes_from_documents([Document(text=text)])
        for node in nodes:
            yield node.text
@define
class StoreDocument(StoreObject, DefineMixin):
    """Abstract class for documents to be stored.

    Subclasses must implement `StoreDocument.get_content` and `StoreDocument.split`."""

    data: tp.Any = define.field()
    """Data."""

    template_context: tp.KwargsLike = define.field(factory=dict)
    """Context used to substitute templates."""

    @classmethod
    def id_from_data(cls, data: tp.Any) -> str:
        """Generate a unique identifier from data.

        The identifier is the MD5 hex digest of the data serialized with
        `vectorbtpro.utils.pickling.dumps`."""
        from vectorbtpro.utils.pickling import dumps

        return hashlib.md5(dumps(data)).hexdigest()

    @classmethod
    def from_data(
        cls: tp.Type[StoreDocumentT],
        data: tp.Any,
        id_: tp.Optional[str] = None,
        **kwargs,
    ) -> StoreDocumentT:
        """Create an instance of `StoreDocument` from data.

        If `id_` is None, it's generated with `StoreDocument.id_from_data`.
        Keyword arguments are passed to the class."""
        if id_ is None:
            id_ = cls.id_from_data(data)
        return cls(id_, data, **kwargs)

    def __attrs_post_init__(self):
        # Fill in a missing identifier; `object.__setattr__` is used to set
        # the field directly on the instance
        if self.id_ is None:
            new_id = self.id_from_data(self.data)
            object.__setattr__(self, "id_", new_id)

    def get_content(self, for_embed: bool = False) -> tp.Optional[str]:
        """Get content.

        Returns None if there's no content."""
        raise NotImplementedError

    def split(self: StoreDocumentT) -> tp.List[StoreDocumentT]:
        """Split document into multiple documents."""
        raise NotImplementedError

    def __str__(self) -> str:
        # `get_content` may return None, but `__str__` must return a string,
        # otherwise `str(document)` raises a TypeError
        content = self.get_content()
        if content is None:
            return ""
        return content
def get_text(self) -> tp.Optional[str]:
    """Get text.

    Data that is None or a string is returned as-is. Otherwise, the text is looked up
    in the data under `TextDocument.text_path`.

    Returns None if no text.

    Raises:
        TypeError: If data is not a string and no text path is provided,
            or if the value found under the text path is not a string.
    """
    from vectorbtpro.utils.search_ import get_pathlike_key

    payload = self.data
    if payload is None or isinstance(payload, str):
        return payload
    if self.text_path is None:
        raise TypeError(f"If text path is not provided, data item must be a string, not {type(self.data)}")

    try:
        found = get_pathlike_key(payload, self.text_path, keep_path=False)
    except (KeyError, IndexError, AttributeError) as e:
        # A missing text field is an error only if `skip_missing` is disabled
        if self.skip_missing:
            return None
        raise e
    if found is not None and not isinstance(found, str):
        raise TypeError(f"Text field must be a string, not {type(found)}")
    return found
Returns None if no metadata.""" from vectorbtpro.utils.formatting import dump metadata = self.get_metadata(for_embed=for_embed) if metadata is None: return None return dump(metadata, **self.dump_kwargs) def get_content(self, for_embed: bool = False) -> tp.Optional[str]: text = self.get_text() metadata_content = self.get_metadata_content(for_embed=for_embed) if text is None and metadata_content is None: return None if text is None: text = "" if metadata_content is None: metadata_content = "" if metadata_content: metadata_template = self.metadata_template if isinstance(metadata_template, str): metadata_template = SafeSub(metadata_template) elif checks.is_function(metadata_template): metadata_template = RepFunc(metadata_template) elif not isinstance(metadata_template, CustomTemplate): raise TypeError(f"Metadata template must be a string, function, or template") template_context = flat_merge_dicts( dict(metadata_content=metadata_content), self.template_context, ) metadata_content = metadata_template.substitute(template_context, eval_id="metadata_template") content_template = self.content_template if isinstance(content_template, str): content_template = SafeSub(content_template) elif checks.is_function(content_template): content_template = RepFunc(content_template) elif not isinstance(content_template, CustomTemplate): raise TypeError(f"Content template must be a string, function, or template") template_context = flat_merge_dicts( dict(metadata_content=metadata_content, text=text), self.template_context, ) return content_template.substitute(template_context, eval_id="content_template") def split(self: TextDocumentT) -> tp.List[TextDocumentT]: from vectorbtpro.utils.search_ import set_pathlike_key text = self.get_text() if text is None: return [self] text_chunks = split_text(text, **self.split_text_kwargs) document_chunks = [] for text_chunk in text_chunks: if not isinstance(self.data, str) and self.text_path is not None: data_chunk = set_pathlike_key( self.data, 
@define
class StoreEmbedding(StoreObject, DefineMixin):
    """Class for embeddings to be stored.

    Adds an optional parent identifier, a list of child identifiers, and the
    embedding itself to the identifier of `StoreObject`."""

    # Defaults to None
    parent_id: tp.Optional[str] = define.field(default=None)
    """Parent object identifier."""

    # Defaults to a new empty list per instance
    child_ids: tp.List[str] = define.field(factory=list)
    """Child object identifiers."""

    # The representation shows only the length of the list, or None if the embedding is empty or None.
    # NOTE(review): annotated as a list of integers; embedding vectors are usually
    # floating-point, so the element type presumably should be float — confirm against the embedders
    embedding: tp.Optional[tp.List[int]] = define.field(default=None, repr=lambda x: f"List[{len(x)}]" if x else None)
    """Embedding."""
class DictStore(ObjectStore):
    """Store class based on a dictionary.

    Objects are kept in a plain dictionary that is created empty on instantiation
    and exposed via `DictStore.store`.

    For defaults, see `chat.obj_store_configs.dict` in `vectorbtpro._settings.knowledge`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "dict"

    _settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.dict"

    def __init__(self, **kwargs) -> None:
        ObjectStore.__init__(self, **kwargs)

        self._store = {}

    @property
    def store(self) -> tp.Dict[str, StoreObjectT]:
        """Store dictionary."""
        return self._store

    def purge(self) -> None:
        """Purge the store: run the base purge, then drop all objects held in memory."""
        ObjectStore.purge(self)
        self.store.clear()

    def __getitem__(self, id_: str) -> StoreObjectT:
        return self.store[id_]

    def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
        self.store[id_] = obj

    def __delitem__(self, id_: str) -> None:
        del self.store[id_]

    def __iter__(self) -> tp.Iterator[str]:
        return iter(self.store)

    def __len__(self) -> int:
        return len(self.store)
Either commits changes to a single file (with index id being the file name), or commits the initial changes to the base file and any other change to patch file(s) (with index id being the directory name). For defaults, see `chat.obj_store_configs.file` in `vectorbtpro._settings.knowledge`.""" _short_name = "file" _settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.file" def __init__( self, dir_path: tp.Optional[tp.PathLike] = None, compression: tp.Union[None, bool, str] = None, save_kwargs: tp.KwargsLike = None, load_kwargs: tp.KwargsLike = None, use_patching: tp.Optional[bool] = None, consolidate: tp.Optional[bool] = None, **kwargs, ) -> None: DictStore.__init__( self, dir_path=dir_path, compression=compression, save_kwargs=save_kwargs, load_kwargs=load_kwargs, use_patching=use_patching, consolidate=consolidate, **kwargs, ) dir_path = self.resolve_setting(dir_path, "dir_path") template_context = self.template_context if isinstance(dir_path, CustomTemplate): cache_dir = self.get_setting("cache_dir", default=None) if cache_dir is not None: if isinstance(cache_dir, CustomTemplate): cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir") template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context) release_dir = self.get_setting("release_dir", default=None) if release_dir is not None: if isinstance(release_dir, CustomTemplate): release_dir = release_dir.substitute(template_context, eval_id="release_dir") template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context) dir_path = dir_path.substitute(template_context, eval_id="dir_path") compression = self.resolve_setting(compression, "compression") save_kwargs = self.resolve_setting(save_kwargs, "save_kwargs", merge=True) load_kwargs = self.resolve_setting(load_kwargs, "load_kwargs", merge=True) use_patching = self.resolve_setting(use_patching, "use_patching") consolidate = self.resolve_setting(consolidate, "consolidate") self._dir_path = dir_path 
def reset_state(self) -> None:
    """Reset the change-tracking state of the store.

    Replaces `store_changes` and `new_keys` with fresh empty containers and
    sets the `consolidate` flag to False, so a requested consolidation does
    not survive a reset.

    Called by `open`, `commit`, `close`, and `purge` of this class."""
    self._consolidate = False
    self._store_changes = {}
    self._new_keys = set()
def commit(self) -> tp.Optional[tp.Path]:
    """Persist the in-memory store to disk.

    With `use_patching`, the store path is a directory holding a "base" file and
    numbered "patch_*" files:

    * If `consolidate` is set, the directory is replaced by a single base file
      holding the entire store.
    * Otherwise, if there are pending changes, they are written to the base file
      (if none exists yet) or to the next patch file.

    Without `use_patching`, the entire store is written to a single file whenever
    `consolidate` is set or there are pending changes.

    If the path currently has the other layout (file vs. directory), it is removed
    first. The change-tracking state is reset afterwards.

    Returns:
        The value returned by `vectorbtpro.utils.pickling.save` for the written
        file, or None if nothing was written."""
    DictStore.commit(self)
    from vectorbtpro.utils.path_ import remove_dir, remove_file
    from vectorbtpro.utils.pickling import save

    def _remove_store_path() -> None:
        # Remove only what is on disk. Calling `self.purge()` here would go through
        # `close()`, which calls `commit()` again (unbounded recursion), and would
        # also clear the in-memory objects that are about to be saved.
        if self.store_path.is_dir():
            remove_dir(self.store_path, with_contents=True)
        else:
            remove_file(self.store_path)

    file_path = None
    if self.use_patching:
        base_path = self.store_path / "base"
        if self.consolidate:
            if self.store_path.exists():
                _remove_store_path()
            file_path = save(
                self.store,
                path=base_path,
                compression=self.compression,
                **self.save_kwargs,
            )
        elif self.store_changes:
            base_content = self.store_changes
            if self.store_path.exists() and self.store_path.is_file():
                _remove_store_path()
                # The single file held every object, so the new base must hold
                # the entire store, not just the pending changes
                base_content = self.store
            if not base_path.exists():
                file_path = save(
                    base_content,
                    path=base_path,
                    compression=self.compression,
                    **self.save_kwargs,
                )
            else:
                file_path = save(
                    self.store_changes,
                    path=self.get_next_patch_path(),
                    compression=self.compression,
                    **self.save_kwargs,
                )
    else:
        if self.consolidate or self.store_changes:
            if self.store_path.exists() and self.store_path.is_dir():
                _remove_store_path()
            file_path = save(
                self.store,
                path=self.store_path,
                compression=self.compression,
                **self.save_kwargs,
            )
    self.reset_state()
    return file_path
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
    """Store `obj` under `id_` and record it as a pending change.

    A key that is not yet in the store is additionally remembered in `new_keys`,
    which `__delitem__` uses to discard objects that were never committed."""
    # Track by the key being written rather than by `obj.id_`: the store and
    # `__delitem__` both work with `id_`, so the bookkeeping must use the same key.
    if id_ not in self:
        self.new_keys.add(id_)
    self.store_changes[id_] = obj
    DictStore.__setitem__(self, id_, obj)
self.get_setting("release_dir", default=None) if release_dir is not None: if isinstance(release_dir, CustomTemplate): release_dir = release_dir.substitute(template_context, eval_id="release_dir") template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context) dir_path = dir_path.substitute(template_context, eval_id="dir_path") mkdir_kwargs = self.resolve_setting(mkdir_kwargs, "mkdir_kwargs", merge=True) dumps_kwargs = self.resolve_setting(dumps_kwargs, "dumps_kwargs", merge=True) loads_kwargs = self.resolve_setting(loads_kwargs, "loads_kwargs", merge=True) open_kwargs = merge_dicts(self.get_settings(inherit=False), kwargs) for arg_name in get_func_arg_names(ObjectStore.__init__) + get_func_arg_names(type(self).__init__): if arg_name in open_kwargs: del open_kwargs[arg_name] if "mirror" in open_kwargs: del open_kwargs["mirror"] self._dir_path = dir_path self._mkdir_kwargs = mkdir_kwargs self._dumps_kwargs = dumps_kwargs self._loads_kwargs = loads_kwargs self._open_kwargs = open_kwargs self._db = None @property def dir_path(self) -> tp.Optional[tp.Path]: """Path to the directory.""" return self._dir_path @property def mkdir_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `vectorbtpro.utils.path_.check_mkdir`.""" return self._mkdir_kwargs @property def dumps_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `vectorbtpro.utils.pickling.dumps`.""" return self._dumps_kwargs @property def loads_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `vectorbtpro.utils.pickling.loads`.""" return self._loads_kwargs @property def open_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `lmdbm.lmdbm.Lmdb.open`.""" return self._open_kwargs @property def db_path(self) -> tp.Path: """Path to the database.""" dir_path = self.dir_path if dir_path is None: dir_path = "." 
def close(self) -> None:
    """Close the store and release the database handle.

    Runs the base close (commit and mark as not opened), then closes the
    LMDB database if one is open and forgets the handle."""
    ObjectStore.close(self)
    # Compare against None explicitly: the handle behaves like a mapping
    # (it supports `len` and iteration), so an empty database is falsy and
    # a plain truthiness test would leave it open.
    if self.db is not None:
        self.db.close()
        self._db = None
def open(self) -> None:
    """Open the cache and, unless deferred, the wrapped object store.

    If mirroring is enabled and `memory_store` holds a mirror of the wrapped
    store, the cache is filled from that mirror and the wrapped store is left
    untouched. Otherwise, the wrapped store is opened right away unless
    `lazy_open` is enabled and no forced open has been requested."""
    DictStore.open(self)
    wrapped = self.obj_store
    if self.mirror and wrapped.mirror_store_id in memory_store:
        self.store.update(memory_store[wrapped.mirror_store_id])
        return
    if self.force_open or not self.lazy_open:
        wrapped.open()
def resolve_obj_store(obj_store: tp.ObjectStoreLike = None) -> tp.MaybeType[ObjectStore]:
    """Resolve a subclass or an instance of `ObjectStore`.

    If nothing is passed, the value is taken from the `obj_store` key of the
    `knowledge.chat` settings. A string is matched case-insensitively against the
    `_short_name` of the classes in this module whose name ends with "Store".

    The following values are supported:

    * "dict" (`DictStore`)
    * "memory" (`MemoryStore`)
    * "file" (`FileStore`)
    * "lmdb" (`LMDBStore`)
    * "cached" (`CachedStore`)
    * A subclass or an instance of `ObjectStore`
    """
    if obj_store is None:
        from vectorbtpro._settings import settings

        obj_store = settings["knowledge"]["chat"]["obj_store"]

    if isinstance(obj_store, str):
        wanted = obj_store.lower()
        module_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
        # First class (in `getmembers` order) whose short name matches
        matched_cls = next(
            (
                cls
                for name, cls in module_classes
                if name.endswith("Store")
                and getattr(cls, "_short_name", None) is not None
                and cls._short_name.lower() == wanted
            ),
            None,
        )
        if matched_cls is None:
            raise ValueError(f"Invalid object store: '{obj_store}'")
        obj_store = matched_cls

    if isinstance(obj_store, type):
        checks.assert_subclass_of(obj_store, ObjectStore, arg_name="obj_store")
    else:
        checks.assert_instance_of(obj_store, ObjectStore, arg_name="obj_store")
    return obj_store
@define
class ScoredDocument(DefineMixin):
    """Class pairing a document with its score.

    The score defaults to NaN. Scored child documents, if any, are held in `child_documents`."""

    document: StoreDocument = define.field()
    """Document."""

    score: float = define.field(default=float("nan"))
    """Score."""

    child_documents: tp.List["ScoredDocument"] = define.field(factory=list)
    """Scored child documents."""
self.resolve_setting(embeddings_kwargs, "embeddings_kwargs", default=None, merge=True) doc_store = self.resolve_setting(doc_store, "doc_store", default=None) doc_store_kwargs = self.resolve_setting(doc_store_kwargs, "doc_store_kwargs", default=None, merge=True) cache_doc_store = self.resolve_setting(cache_doc_store, "cache_doc_store") emb_store = self.resolve_setting(emb_store, "emb_store", default=None) emb_store_kwargs = self.resolve_setting(emb_store_kwargs, "emb_store_kwargs", default=None, merge=True) cache_emb_store = self.resolve_setting(cache_emb_store, "cache_emb_store") score_func = self.resolve_setting(score_func, "score_func") score_agg_func = self.resolve_setting(score_agg_func, "score_agg_func") show_progress = self.resolve_setting(show_progress, "show_progress") pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True) template_context = self.resolve_setting(template_context, "template_context", merge=True) obj_store = self.get_setting("obj_store", default=None) obj_store_kwargs = self.get_setting("obj_store_kwargs", default=None, merge=True) if doc_store is None: doc_store = obj_store doc_store_kwargs = merge_dicts(obj_store_kwargs, doc_store_kwargs) if emb_store is None: emb_store = obj_store emb_store_kwargs = merge_dicts(obj_store_kwargs, emb_store_kwargs) embeddings = resolve_embeddings(embeddings) if isinstance(embeddings, type): embeddings_kwargs = dict(embeddings_kwargs) embeddings_kwargs["template_context"] = merge_dicts( template_context, embeddings_kwargs.get("template_context", None) ) embeddings = embeddings(**embeddings_kwargs) elif embeddings_kwargs: embeddings = embeddings.replace(**embeddings_kwargs) if isinstance(self._settings_path, list): if not isinstance(self._settings_path[-1], str): raise TypeError("_settings_path[-1] for DocumentRanker and its subclasses must be a string") target_settings_path = self._settings_path[-1] elif isinstance(self._settings_path, str): target_settings_path = self._settings_path else: 
raise TypeError("_settings_path for DocumentRanker and its subclasses must be a list or string") doc_store = resolve_obj_store(doc_store) if not isinstance(doc_store._settings_path, str): raise TypeError("_settings_path for ObjectStore and its subclasses must be a string") doc_store_cls = doc_store if isinstance(doc_store, type) else type(doc_store) doc_store_settings_path = doc_store._settings_path doc_store_settings_path = doc_store_settings_path.replace("knowledge.chat", target_settings_path) doc_store_settings_path = doc_store_settings_path.replace("obj_store", "doc_store") with ExtSettingsPath([(doc_store_cls, doc_store_settings_path)]): if isinstance(doc_store, type): doc_store_kwargs = dict(doc_store_kwargs) if dataset_id is not None and "store_id" not in doc_store_kwargs: doc_store_kwargs["store_id"] = dataset_id doc_store_kwargs["template_context"] = merge_dicts( template_context, doc_store_kwargs.get("template_context", None) ) doc_store = doc_store(**doc_store_kwargs) elif doc_store_kwargs: doc_store = doc_store.replace(**doc_store_kwargs) if cache_doc_store and not isinstance(doc_store, CachedStore): doc_store = CachedStore(doc_store) emb_store = resolve_obj_store(emb_store) if not isinstance(emb_store._settings_path, str): raise TypeError("_settings_path for ObjectStore and its subclasses must be a string") emb_store_cls = emb_store if isinstance(emb_store, type) else type(emb_store) emb_store_settings_path = emb_store._settings_path emb_store_settings_path = emb_store_settings_path.replace("knowledge.chat", target_settings_path) emb_store_settings_path = emb_store_settings_path.replace("obj_store", "emb_store") with ExtSettingsPath([(emb_store_cls, emb_store_settings_path)]): if isinstance(emb_store, type): emb_store_kwargs = dict(emb_store_kwargs) if dataset_id is not None and "store_id" not in emb_store_kwargs: emb_store_kwargs["store_id"] = dataset_id emb_store_kwargs["template_context"] = merge_dicts( template_context, 
emb_store_kwargs.get("template_context", None) ) emb_store = emb_store(**emb_store_kwargs) elif emb_store_kwargs: emb_store = emb_store.replace(**emb_store_kwargs) if cache_emb_store and not isinstance(emb_store, CachedStore): emb_store = CachedStore(emb_store) if isinstance(score_agg_func, str): score_agg_func = getattr(np, score_agg_func) self._embeddings = embeddings self._doc_store = doc_store self._emb_store = emb_store self._score_func = score_func self._score_agg_func = score_agg_func self._show_progress = show_progress self._pbar_kwargs = pbar_kwargs self._template_context = template_context @property def embeddings(self) -> Embeddings: """An instance of `Embeddings`.""" return self._embeddings @property def doc_store(self) -> ObjectStore: """An instance of `ObjectStore` for documents.""" return self._doc_store @property def emb_store(self) -> ObjectStore: """An instance of `ObjectStore` for embeddings.""" return self._emb_store @property def score_func(self) -> tp.Union[str, tp.Callable]: """Score function. See `DocumentRanker.compute_score`.""" return self._score_func @property def score_agg_func(self) -> tp.Callable: """Score aggregation function.""" return self._score_agg_func @property def show_progress(self) -> tp.Optional[bool]: """Whether to show progress bar.""" return self._show_progress @property def pbar_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.""" return self._pbar_kwargs @property def template_context(self) -> tp.Kwargs: """Context used to substitute templates.""" return self._template_context def embed_documents( self, documents: tp.Iterable[StoreDocument], refresh: bool = False, refresh_documents: tp.Optional[bool] = None, refresh_embeddings: tp.Optional[bool] = None, return_embeddings: bool = False, return_documents: bool = False, ) -> tp.Optional[tp.EmbeddedDocuments]: """Embed documents. 
Enable `refresh` or its sub-arguments to refresh documents and/or embeddings in their particular stores. Without refreshing, will rely on the persisted objects. If `return_embeddings` and `return_documents` are both False, returns nothing. If `return_embeddings` and `return_documents` are both True, for each document, returns the document and either an embedding or a list of document chunks and their embeddings. If `return_documents` is False, returns only embeddings.""" if refresh_documents is None: refresh_documents = refresh if refresh_embeddings is None: refresh_embeddings = refresh with self.doc_store, self.emb_store: documents = list(documents) documents_to_split = [] document_splits = {} for document in documents: if refresh_documents or refresh_embeddings or document.id_ not in self.emb_store: documents_to_split.append(document) if documents_to_split: from vectorbtpro.utils.pbar import ProgressBar pbar_kwargs = merge_dicts(dict(prefix="split_documents"), self.pbar_kwargs) with ProgressBar( total=len(documents_to_split), show_progress=self.show_progress, **pbar_kwargs, ) as pbar: for document in documents_to_split: document_splits[document.id_] = document.split() pbar.update() obj_contents = {} for document in documents: if refresh_documents or document.id_ not in self.doc_store: self.doc_store[document.id_] = document if document.id_ in document_splits: document_chunks = document_splits[document.id_] obj = StoreEmbedding(document.id_) for document_chunk in document_chunks: if document_chunk.id_ != document.id_: if refresh_documents or document_chunk.id_ not in self.doc_store: self.doc_store[document_chunk.id_] = document_chunk if refresh_embeddings or document_chunk.id_ not in self.emb_store: child_obj = StoreEmbedding(document_chunk.id_, parent_id=document.id_) self.emb_store[child_obj.id_] = child_obj else: child_obj = self.emb_store[document_chunk.id_] obj.child_ids.append(child_obj.id_) if not child_obj.embedding: content = 
def compute_score(
    self,
    emb1: tp.Union[tp.MaybeIterable[tp.List[float]], np.ndarray],
    emb2: tp.Union[tp.MaybeIterable[tp.List[float]], np.ndarray],
) -> tp.Union[float, np.ndarray]:
    """Compute scores between embeddings, which can be either single or multiple.

    Supported score functions are 'cosine', 'euclidean' (inverse distance, with infinity
    for identical embeddings), and 'dot'. A score function can also be a callable that
    takes two 2-dim NumPy arrays and returns one 2-dim NumPy array.

    Returns a float if both inputs are single embeddings, a 1-dim array if exactly one
    of them is single, and the full score matrix otherwise."""
    left = np.asarray(emb1)
    right = np.asarray(emb2)
    left_is_single = left.ndim == 1
    right_is_single = right.ndim == 1
    # Promote single embeddings to one-row matrices so that all branches work on 2-dim input
    if left_is_single:
        left = left.reshape(1, -1)
    if right_is_single:
        right = right.reshape(1, -1)

    if not isinstance(self.score_func, str):
        score_matrix = self.score_func(left, right)
    else:
        metric = self.score_func.lower()
        if metric == "cosine":
            left_unit = left / np.linalg.norm(left, axis=1, keepdims=True)
            right_unit = right / np.linalg.norm(right, axis=1, keepdims=True)
            # Zero-length embeddings divide to NaN; treat them as zero vectors
            left_unit = np.nan_to_num(left_unit)
            right_unit = np.nan_to_num(right_unit)
            score_matrix = np.dot(left_unit, right_unit.T)
        elif metric == "euclidean":
            pairwise_diff = left[:, np.newaxis, :] - right[np.newaxis, :, :]
            distances = np.linalg.norm(pairwise_diff, axis=2)
            # Entries with zero distance keep the pre-filled infinity
            inverse = np.full_like(distances, np.inf)
            score_matrix = np.divide(1, distances, where=distances != 0, out=inverse)
        elif metric == "dot":
            score_matrix = np.dot(left, right.T)
        else:
            raise ValueError(f"Invalid distance function: '{self.score_func}'")

    if left_is_single and right_is_single:
        return float(score_matrix[0, 0])
    if left_is_single or right_is_single:
        return score_matrix.flatten()
    return score_matrix
self.embed_documents( documents, refresh=refresh, refresh_documents=refresh_documents, refresh_embeddings=refresh_embeddings, ) if return_chunks: document_chunks = [] for document in documents: obj = self.emb_store[document.id_] if obj.child_ids: for child_id in obj.child_ids: document_chunk = self.doc_store[child_id] document_chunks.append(document_chunk) elif not obj.parent_id or obj.parent_id not in self.doc_store: document_chunk = self.doc_store[obj.id_] document_chunks.append(document_chunk) documents = document_chunks elif not documents_provided: document_parents = [] for document in documents: obj = self.emb_store[document.id_] if not obj.parent_id or obj.parent_id not in self.doc_store: document_parent = self.doc_store[obj.id_] document_parents.append(document_parent) documents = document_parents obj_embeddings = {} for document in documents: obj = self.emb_store[document.id_] if obj.embedding: obj_embeddings[obj.id_] = obj.embedding elif obj.child_ids: for child_id in obj.child_ids: child_obj = self.emb_store[child_id] if child_obj.embedding: obj_embeddings[child_id] = child_obj.embedding if obj_embeddings: query_embedding = self.embeddings.get_embedding(query) scores = self.compute_score(query_embedding, list(obj_embeddings.values())) obj_scores = dict(zip(obj_embeddings.keys(), scores)) else: obj_scores = {} scores = [] for document in documents: obj = self.emb_store[document.id_] child_scores = [] if obj.child_ids: for child_id in obj.child_ids: if child_id in obj_scores: child_score = obj_scores[child_id] if return_documents: child_document = self.doc_store[child_id] child_scores.append(ScoredDocument(child_document, score=child_score)) else: child_scores.append(child_score) if child_scores: if return_documents: doc_score = self.score_agg_func([document.score for document in child_scores]) else: doc_score = self.score_agg_func(child_scores) else: doc_score = float("nan") else: if obj.id_ in obj_scores: doc_score = obj_scores[obj.id_] else: doc_score = 
@classmethod
def resolve_top_k(cls, scores: tp.Iterable[float], top_k: tp.TopKLike = None) -> tp.Optional[int]:
    """Resolve `top_k` based on _sorted_ scores.

    Supported values are integers (top number), floats (top %), strings (supported methods
    are 'elbow' and 'kmeans'), as well as callables that should take a 1-dim NumPy array
    and return an integer or a float.

    Filters out NaN before computation (requires them to be at the tail).

    With fewer than two valid scores, the string methods return the number of valid
    scores since there is nothing to split.

    Returns None if `top_k` is None. Raises `ValueError` for an unknown string method."""
    if top_k is None:
        return None
    scores = np.asarray(scores)
    scores = scores[~np.isnan(scores)]
    if isinstance(top_k, str):
        method = top_k.lower()
        if method not in ("elbow", "kmeans"):
            raise ValueError(f"Invalid top_k method: '{top_k}'")
        if scores.size < 2:
            # A single score has no differences to take the argmax of,
            # and two clusters cannot be fitted to fewer than two samples
            return int(scores.size)
        if method == "elbow":
            # Cut right after the largest drop between consecutive scores
            diffs = np.diff(scores)
            top_k = int(np.argmax(-diffs)) + 1
        else:
            from sklearn.cluster import KMeans

            kmeans = KMeans(n_clusters=2, random_state=0).fit(scores.reshape(-1, 1))
            high_score_cluster = np.argmax(kmeans.cluster_centers_)
            top_k_indices = np.where(kmeans.labels_ == high_score_cluster)[0]
            # Keep everything up to the last member of the high-score cluster
            top_k = int(top_k_indices.max()) + 1
    elif callable(top_k):
        top_k = top_k(scores)
    if checks.is_float(top_k):
        # A float is a fraction of the number of valid scores
        top_k = int(top_k * len(scores))
    return top_k

@classmethod
def top_k_from_cutoff(cls, scores: tp.Iterable[float], cutoff: tp.Optional[float] = None) -> tp.Optional[int]:
    """Get `top_k` from `cutoff` based on _sorted_ scores.

    Counts the non-NaN scores that are greater than or equal to `cutoff`.
    Returns None if `cutoff` is None."""
    if cutoff is None:
        return None
    scores = np.asarray(scores)
    scores = scores[~np.isnan(scores)]
    return len(scores[scores >= cutoff])

def rank_documents(
    self,
    query: str,
    documents: tp.Optional[tp.Iterable[StoreDocument]] = None,
    top_k: tp.TopKLike = None,
    min_top_k: tp.TopKLike = None,
    max_top_k: tp.TopKLike = None,
    cutoff: tp.Optional[float] = None,
    refresh: bool = False,
    refresh_documents: tp.Optional[bool] = None,
    refresh_embeddings: tp.Optional[bool] = None,
    return_chunks: bool = False,
    return_scores: bool = False,
) -> tp.RankedDocuments:
    """Sort documents by relevance to a query.

    Documents are scored with `DocumentRanker.score_documents` and sorted by score in
    descending order, with NaN scores placed last.

    Top-k, minimum top-k, and maximum top-k are resolved with `DocumentRanker.resolve_top_k`.
    Score cutoff is converted into top-k with `DocumentRanker.top_k_from_cutoff`.
    Minimum and maximum top-k are used to override non-integer top-k and cutoff;
    it has no effect on the integer top-k, which can be outside the top-k bounds
    and won't be overridden. If neither top-k nor cutoff is provided, the bounds
    are not applied and all documents are returned.

    Returns scored documents if `return_scores` is True, otherwise the documents themselves.
    Raises `ValueError` if no document remains after applying top-k and cutoff."""
    scored_documents = self.score_documents(
        query,
        documents=documents,
        refresh=refresh,
        refresh_documents=refresh_documents,
        refresh_embeddings=refresh_embeddings,
        return_chunks=return_chunks,
        return_documents=True,
    )
    scored_documents = sorted(scored_documents, key=lambda x: (not np.isnan(x.score), x.score), reverse=True)
    scores = [document.score for document in scored_documents]

    int_top_k = top_k is not None and checks.is_int(top_k)
    top_k = self.resolve_top_k(scores, top_k=top_k)
    min_top_k = self.resolve_top_k(scores, top_k=min_top_k)
    max_top_k = self.resolve_top_k(scores, top_k=max_top_k)
    cutoff = self.top_k_from_cutoff(scores, cutoff=cutoff)
    # Bounds can only be compared against a top-k that was actually provided;
    # comparing them against None would raise a TypeError
    if top_k is not None and not int_top_k:
        if min_top_k is not None and min_top_k > top_k:
            top_k = min_top_k
        if max_top_k is not None and max_top_k < top_k:
            top_k = max_top_k
    if cutoff is not None:
        if min_top_k is not None and min_top_k > cutoff:
            cutoff = min_top_k
        if max_top_k is not None and max_top_k < cutoff:
            cutoff = max_top_k
    if top_k is None:
        top_k = len(scores)
    if cutoff is None:
        cutoff = len(scores)
    # Both constraints must hold, so the stricter one wins
    top_k = min(top_k, cutoff)
    if top_k == 0:
        raise ValueError("No documents selected after ranking. Change top_k or cutoff.")
    scored_documents = scored_documents[:top_k]
    if return_scores:
        return scored_documents
    return [document.document for document in scored_documents]
Keyword arguments are passed to either initialize a class or replace an instance of `DocumentRanker`.""" if doc_ranker is None: doc_ranker = DocumentRanker if isinstance(doc_ranker, type): checks.assert_subclass_of(doc_ranker, DocumentRanker, "doc_ranker") doc_ranker = doc_ranker(**kwargs) else: checks.assert_instance_of(doc_ranker, DocumentRanker, "doc_ranker") if kwargs: doc_ranker = doc_ranker.replace(**kwargs) return doc_ranker.rank_documents( query, documents=documents, top_k=top_k, min_top_k=min_top_k, max_top_k=max_top_k, cutoff=cutoff, refresh=refresh, refresh_documents=refresh_documents, refresh_embeddings=refresh_embeddings, return_chunks=return_chunks, return_scores=return_scores, ) RankableT = tp.TypeVar("RankableT", bound="Rankable") class Rankable(HasSettings): """Abstract class that can be ranked.""" _settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat"] def embed( self: RankableT, refresh: bool = False, refresh_documents: tp.Optional[bool] = None, refresh_embeddings: tp.Optional[bool] = None, return_embeddings: bool = False, return_documents: bool = False, **kwargs, ) -> tp.Optional[RankableT]: """Embed documents.""" raise NotImplementedError def rank( self: RankableT, query: str, top_k: tp.TopKLike = None, min_top_k: tp.TopKLike = None, max_top_k: tp.TopKLike = None, cutoff: tp.Optional[float] = None, refresh: bool = False, refresh_documents: tp.Optional[bool] = None, refresh_embeddings: tp.Optional[bool] = None, return_chunks: bool = False, return_scores: bool = False, **kwargs, ) -> RankableT: """Rank documents by their relevance to a query.""" raise NotImplementedError # ############# Contexting ############# # class Contextable(HasSettings): """Abstract class that can be converted into a context.""" _settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat"] def to_context(self, *args, **kwargs) -> str: """Convert to a context.""" raise NotImplementedError def count_tokens( self, to_context_kwargs: tp.KwargsLike = None, 
tokenizer: tp.TokenizerLike = None, tokenizer_kwargs: tp.KwargsLike = None, ) -> int: """Count the number of tokens in the context.""" to_context_kwargs = self.resolve_setting(to_context_kwargs, "to_context_kwargs", merge=True) tokenizer = self.resolve_setting(tokenizer, "tokenizer", default=None) tokenizer_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True) context = self.to_context(**to_context_kwargs) tokenizer = resolve_tokenizer(tokenizer) if isinstance(tokenizer, type): tokenizer = tokenizer(**tokenizer_kwargs) elif tokenizer_kwargs: tokenizer = tokenizer.replace(**tokenizer_kwargs) return len(tokenizer.encode(context)) def create_chat( self, to_context_kwargs: tp.KwargsLike = None, completions: tp.CompletionsLike = None, **kwargs, ) -> tp.Completions: """Create a chat by returning an instance of `Completions`. Uses `Contextable.to_context` to turn this instance to a context. Usage: ```pycon >>> chat = asset.create_chat() >>> chat.get_completion("What's the value under 'xyz'?") The value under 'xyz' is 123. >>> chat.get_completion("Are you sure?") Yes, I am sure. The value under 'xyz' is 123 for the entry where `s` is "EFG". ```""" to_context_kwargs = self.resolve_setting(to_context_kwargs, "to_context_kwargs", merge=True) context = self.to_context(**to_context_kwargs) completions = resolve_completions(completions=completions) if isinstance(completions, type): completions = completions(context=context, **kwargs) else: completions = completions.replace(context=context, **kwargs) return completions @hybrid_method def chat( cls_or_self, message: str, chat_history: tp.Optional[tp.ChatHistory] = None, *, return_chat: bool = False, **kwargs, ) -> tp.MaybeChatOutput: """Chat with an LLM while using the instance as a context. Uses `Contextable.create_chat` and then `Completions.get_completion`. !!! note Context is recalculated each time this method is invoked. 
For multiple turns, it's more efficient to use `Contextable.create_chat`. Usage: ```pycon >>> asset.chat("What's the value under 'xyz'?") The value under 'xyz' is 123. >>> chat_history = [] >>> asset.chat("What's the value under 'xyz'?", chat_history=chat_history) The value under 'xyz' is 123. >>> asset.chat("Are you sure?", chat_history=chat_history) Yes, I am sure. The value under 'xyz' is 123 for the entry where `s` is "EFG". ``` """ if isinstance(cls_or_self, type): args, kwargs = get_forward_args(super().chat, locals()) return super().chat(*args, **kwargs) completions = cls_or_self.create_chat(chat_history=chat_history, **kwargs) if return_chat: return completions.get_completion(message), completions return completions.get_completion(message) class RankContextable(Rankable, Contextable): """Abstract class that combines both `Rankable` and `Contextable` to rank a context.""" @hybrid_method def chat( cls_or_self, message: str, chat_history: tp.Optional[tp.ChatHistory] = None, *, incl_past_queries: tp.Optional[bool] = None, rank: tp.Optional[bool] = None, top_k: tp.TopKLike = None, min_top_k: tp.TopKLike = None, max_top_k: tp.TopKLike = None, cutoff: tp.Optional[float] = None, return_chunks: tp.Optional[bool] = None, rank_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeChatOutput: """See `Contextable.chat`. 
If `rank` is True, or `rank` is None and any of `top_k`, `min_top_k`, `max_top_k`, `cutoff`, or `return_chunks` is set, will rank the documents with `Rankable.rank` first.""" if isinstance(cls_or_self, type): args, kwargs = get_forward_args(super().chat, locals()) return super().chat(*args, **kwargs) incl_past_queries = cls_or_self.resolve_setting(incl_past_queries, "incl_past_queries") rank = cls_or_self.resolve_setting(rank, "rank") rank_kwargs = cls_or_self.resolve_setting(rank_kwargs, "rank_kwargs", merge=True) def_top_k = rank_kwargs.pop("top_k") if top_k is None: top_k = def_top_k def_min_top_k = rank_kwargs.pop("min_top_k") if min_top_k is None: min_top_k = def_min_top_k def_max_top_k = rank_kwargs.pop("max_top_k") if max_top_k is None: max_top_k = def_max_top_k def_cutoff = rank_kwargs.pop("cutoff") if cutoff is None: cutoff = def_cutoff def_return_chunks = rank_kwargs.pop("return_chunks") if return_chunks is None: return_chunks = def_return_chunks if rank or (rank is None and (top_k or min_top_k or max_top_k or cutoff or return_chunks)): if incl_past_queries and chat_history is not None: queries = [] for message_dct in chat_history: if "role" in message_dct and message_dct["role"] == "user": queries.append(message_dct["content"]) queries.append(message) if len(queries) > 1: query = "\n\n".join(queries) else: query = queries[0] else: query = message _cls_or_self = cls_or_self.rank( query, top_k=top_k, min_top_k=min_top_k, max_top_k=max_top_k, cutoff=cutoff, return_chunks=return_chunks, **rank_kwargs, ) else: _cls_or_self = cls_or_self return Contextable.chat.__func__(_cls_or_self, message, chat_history, **kwargs) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. 
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""Custom asset function classes."""

from vectorbtpro import _typing as tp
from vectorbtpro.utils.config import flat_merge_dicts
from vectorbtpro.utils.knowledge.base_asset_funcs import AssetFunc, RemoveAssetFunc
from vectorbtpro.utils.knowledge.formatting import to_markdown, to_html, format_html

__all__ = []


def _resolve_metadata_settings(
    asset_cls: tp.Type[tp.KnowledgeAsset],
    minimize_metadata: tp.Optional[bool],
    minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]],
    clean_metadata: tp.Optional[bool],
    clean_metadata_kwargs: tp.KwargsLike,
    dump_metadata_kwargs: tp.KwargsLike,
) -> dict:
    """Resolve the metadata-related arguments shared by all `prepare` methods in this module.

    Each argument is resolved against the settings of `asset_cls`. The cleaning arguments get
    `FindRemoveAssetFunc.is_empty_func` as the default target and are then prepared with
    `FindRemoveAssetFunc.prepare`; the dumping arguments are prepared with `DumpAssetFunc.prepare`.

    Returns a dict keyed by the argument names."""
    from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc

    minimize_metadata = asset_cls.resolve_setting(minimize_metadata, "minimize_metadata")
    minimize_keys = asset_cls.resolve_setting(minimize_keys, "minimize_keys")
    clean_metadata = asset_cls.resolve_setting(clean_metadata, "clean_metadata")
    clean_metadata_kwargs = asset_cls.resolve_setting(clean_metadata_kwargs, "clean_metadata_kwargs", merge=True)
    dump_metadata_kwargs = asset_cls.resolve_setting(dump_metadata_kwargs, "dump_metadata_kwargs", merge=True)

    clean_metadata_kwargs = flat_merge_dicts(dict(target=FindRemoveAssetFunc.is_empty_func), clean_metadata_kwargs)
    _, clean_metadata_kwargs = FindRemoveAssetFunc.prepare(**clean_metadata_kwargs)
    _, dump_metadata_kwargs = DumpAssetFunc.prepare(**dump_metadata_kwargs)
    return dict(
        minimize_metadata=minimize_metadata,
        minimize_keys=minimize_keys,
        clean_metadata=clean_metadata,
        clean_metadata_kwargs=clean_metadata_kwargs,
        dump_metadata_kwargs=dump_metadata_kwargs,
    )


class ToMarkdownAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.VBTAsset.to_markdown`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "to_markdown"

    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def prepare(
        cls,
        root_metadata_key: tp.Optional[tp.Key] = None,
        minimize_metadata: tp.Optional[bool] = None,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: tp.Optional[bool] = None,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **to_markdown_kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve arguments against the settings of `asset_cls` (`VBTAsset` if None).

        Returns an empty tuple of positional arguments and the keyword arguments for
        `ToMarkdownAssetFunc.call`."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.custom_assets import VBTAsset

            asset_cls = VBTAsset
        root_metadata_key = asset_cls.resolve_setting(root_metadata_key, "root_metadata_key")
        metadata_settings = _resolve_metadata_settings(
            asset_cls,
            minimize_metadata,
            minimize_keys,
            clean_metadata,
            clean_metadata_kwargs,
            dump_metadata_kwargs,
        )
        to_markdown_kwargs = asset_cls.resolve_setting(to_markdown_kwargs, "to_markdown_kwargs", merge=True)
        return (), {
            "root_metadata_key": root_metadata_key,
            **metadata_settings,
            **to_markdown_kwargs,
        }

    @classmethod
    def get_markdown_metadata(
        cls,
        d: dict,
        root_metadata_key: tp.Optional[tp.Key] = None,
        allow_empty: tp.Optional[bool] = None,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        **to_markdown_kwargs,
    ) -> str:
        """Get metadata in Markdown format.

        Metadata is everything in `d` except the key "content". It is optionally minimized
        (`RemoveAssetFunc`) and cleaned (`FindRemoveAssetFunc`), nested under `root_metadata_key`
        if provided, dumped with `DumpAssetFunc`, and wrapped into a fenced code block.

        If `allow_empty` is None, it becomes True only when `root_metadata_key` is provided.
        Returns an empty string if there is no metadata left and `allow_empty` is False."""
        from vectorbtpro.utils.formatting import get_dump_language
        from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc

        if allow_empty is None:
            allow_empty = root_metadata_key is not None
        if clean_metadata_kwargs is None:
            clean_metadata_kwargs = {}
        if dump_metadata_kwargs is None:
            dump_metadata_kwargs = {}

        # Shallow copy: the caller's dict must keep its content
        metadata = dict(d)
        metadata.pop("content", None)
        if metadata and minimize_metadata and minimize_keys:
            metadata = RemoveAssetFunc.call(metadata, minimize_keys, skip_missing=True)
        if metadata and clean_metadata:
            metadata = FindRemoveAssetFunc.call(metadata, **clean_metadata_kwargs)
        if not metadata and not allow_empty:
            return ""
        if root_metadata_key is not None:
            if not metadata:
                metadata = None
            metadata = {root_metadata_key: metadata}
        text = DumpAssetFunc.call(metadata, **dump_metadata_kwargs).strip()
        dump_engine = dump_metadata_kwargs.get("dump_engine", "nestedtext")
        dump_language = get_dump_language(dump_engine)
        text = f"```{dump_language}\n{text}\n```"
        return to_markdown(text, **to_markdown_kwargs)

    @classmethod
    def get_markdown_content(cls, d: dict, **kwargs) -> str:
        """Get content in Markdown format.

        Returns an empty string if "content" is missing or None."""
        content = d.get("content")
        if content is None:
            return ""
        return to_markdown(content, **kwargs)

    @classmethod
    def call(
        cls,
        d: tp.Any,
        root_metadata_key: tp.Optional[tp.Key] = None,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        **to_markdown_kwargs,
    ) -> tp.Any:
        """Convert a data item (string or dict) to Markdown.

        A string is treated as a dict with the single key "content". Metadata and content are
        separated by a blank line. Raises `TypeError` for any other type."""
        if not isinstance(d, (str, dict)):
            raise TypeError("Data item must be a string or dict")
        if isinstance(d, str):
            d = dict(content=d)

        markdown_metadata = cls.get_markdown_metadata(
            d,
            root_metadata_key=root_metadata_key,
            minimize_metadata=minimize_metadata,
            minimize_keys=minimize_keys,
            clean_metadata=clean_metadata,
            clean_metadata_kwargs=clean_metadata_kwargs,
            dump_metadata_kwargs=dump_metadata_kwargs,
            **to_markdown_kwargs,
        )
        markdown_content = cls.get_markdown_content(d, **to_markdown_kwargs)
        if markdown_metadata and markdown_content:
            return markdown_metadata + "\n\n" + markdown_content
        if markdown_metadata:
            return markdown_metadata
        return markdown_content


class ToHTMLAssetFunc(ToMarkdownAssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.VBTAsset.to_html`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "to_html"

    @classmethod
    def prepare(
        cls,
        root_metadata_key: tp.Optional[tp.Key] = None,
        minimize_metadata: tp.Optional[bool] = None,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: tp.Optional[bool] = None,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        format_html_kwargs: tp.KwargsLike = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **to_html_kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve arguments against the settings of `asset_cls` (`VBTAsset` if None).

        Returns an empty tuple of positional arguments and the keyword arguments for
        `ToHTMLAssetFunc.call`."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.custom_assets import VBTAsset

            asset_cls = VBTAsset
        root_metadata_key = asset_cls.resolve_setting(root_metadata_key, "root_metadata_key")
        metadata_settings = _resolve_metadata_settings(
            asset_cls,
            minimize_metadata,
            minimize_keys,
            clean_metadata,
            clean_metadata_kwargs,
            dump_metadata_kwargs,
        )
        to_markdown_kwargs = asset_cls.resolve_setting(to_markdown_kwargs, "to_markdown_kwargs", merge=True)
        format_html_kwargs = asset_cls.resolve_setting(format_html_kwargs, "format_html_kwargs", merge=True)
        to_html_kwargs = asset_cls.resolve_setting(to_html_kwargs, "to_html_kwargs", merge=True)
        return (), {
            "root_metadata_key": root_metadata_key,
            **metadata_settings,
            "to_markdown_kwargs": to_markdown_kwargs,
            "format_html_kwargs": format_html_kwargs,
            **to_html_kwargs,
        }

    @classmethod
    def get_html_metadata(
        cls,
        d: dict,
        root_metadata_key: tp.Optional[tp.Key] = None,
        allow_empty: tp.Optional[bool] = None,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        **to_html_kwargs,
    ) -> str:
        """Get metadata in HTML format.

        Builds Markdown with `ToMarkdownAssetFunc.get_markdown_metadata` and converts it with
        `to_html`. Returns an empty string if there is no Markdown metadata."""
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}
        metadata = cls.get_markdown_metadata(
            d,
            root_metadata_key=root_metadata_key,
            allow_empty=allow_empty,
            minimize_metadata=minimize_metadata,
            minimize_keys=minimize_keys,
            clean_metadata=clean_metadata,
            clean_metadata_kwargs=clean_metadata_kwargs,
            dump_metadata_kwargs=dump_metadata_kwargs,
            **to_markdown_kwargs,
        )
        if not metadata:
            return ""
        return to_html(metadata, **to_html_kwargs)

    @classmethod
    def get_html_content(cls, d: dict, to_markdown_kwargs: tp.KwargsLike = None, **kwargs) -> str:
        """Get content in HTML format.

        Builds Markdown with `ToMarkdownAssetFunc.get_markdown_content` and converts it with `to_html`."""
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}
        content = cls.get_markdown_content(d, **to_markdown_kwargs)
        return to_html(content, **kwargs)

    @staticmethod
    def _as_item(d: tp.Any) -> dict:
        """Validate a single data item and wrap a string into a dict with the key "content"."""
        if not isinstance(d, (str, dict)):
            raise TypeError("Data item must be a string, dict, or list of such")
        if isinstance(d, str):
            return dict(content=d)
        return d

    @classmethod
    def call(
        cls,
        d: tp.Any,
        root_metadata_key: tp.Optional[tp.Key] = None,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        format_html_kwargs: tp.KwargsLike = None,
        **to_html_kwargs,
    ) -> tp.Any:
        """Convert a data item (string, dict, or list of such) to an HTML page using `format_html`.

        For a list, only the metadata of each item is rendered (joined by newlines) and the title
        is "/". For a single item, both metadata and content are rendered and the title is the
        item's "link" if present, otherwise an empty string.

        Raises `TypeError` for any other type."""
        if format_html_kwargs is None:
            # Unpacking None with ** would raise TypeError
            format_html_kwargs = {}
        metadata_kwargs = dict(
            root_metadata_key=root_metadata_key,
            minimize_metadata=minimize_metadata,
            minimize_keys=minimize_keys,
            clean_metadata=clean_metadata,
            clean_metadata_kwargs=clean_metadata_kwargs,
            dump_metadata_kwargs=dump_metadata_kwargs,
            to_markdown_kwargs=to_markdown_kwargs,
        )

        if isinstance(d, list):
            html_metadata = [
                cls.get_html_metadata(cls._as_item(_d), **metadata_kwargs, **to_html_kwargs) for _d in d
            ]
            return format_html(
                title="/",
                html_metadata="\n".join(html_metadata),
                **format_html_kwargs,
            )

        d = cls._as_item(d)
        html_metadata = cls.get_html_metadata(d, **metadata_kwargs, **to_html_kwargs)
        html_content = cls.get_html_content(
            d,
            to_markdown_kwargs=to_markdown_kwargs,
            **to_html_kwargs,
        )
        return format_html(
            title=d["link"] if "link" in d else "",
            html_metadata=html_metadata,
            html_content=html_content,
            **format_html_kwargs,
        )


class AggMessageAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_messages`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "agg_message"

    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def prepare(
        cls,
        minimize_metadata: tp.Optional[bool] = None,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: tp.Optional[bool] = None,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve arguments against the settings of `asset_cls` (`MessagesAsset` if None).

        `to_markdown_kwargs` is forwarded as provided (not resolved against the settings).

        Returns an empty tuple of positional arguments and the keyword arguments for
        `AggMessageAssetFunc.call`."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.custom_assets import MessagesAsset

            asset_cls = MessagesAsset
        metadata_settings = _resolve_metadata_settings(
            asset_cls,
            minimize_metadata,
            minimize_keys,
            clean_metadata,
            clean_metadata_kwargs,
            dump_metadata_kwargs,
        )
        return (), {
            **metadata_settings,
            # Previously accepted but dropped, so user-provided values never reached `call`
            "to_markdown_kwargs": to_markdown_kwargs,
            **kwargs,
        }

    @classmethod
    def call(
        cls,
        d: tp.Any,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
    ) -> tp.Any:
        """Merge the attachments of a message into its content.

        Returns a shallow copy of `d` if it has no key "attachments". Otherwise, the key is removed
        and, for each attachment, its Markdown metadata (under the root key "attachment") and its
        stripped content are appended to the stripped message content, separated by blank lines.

        Raises `TypeError` if `d` is not a dict."""
        if not isinstance(d, dict):
            raise TypeError("Data item must be a dict")
        if "attachments" not in d:
            return dict(d)
        if clean_metadata_kwargs is None:
            clean_metadata_kwargs = {}
        if dump_metadata_kwargs is None:
            dump_metadata_kwargs = {}
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}

        new_d = dict(d)
        text = new_d["content"].strip()
        attachments = new_d.pop("attachments", [])
        for attachment in attachments or ():
            content = attachment["content"].strip()
            if text:
                text += "\n\n"
            text += ToMarkdownAssetFunc.get_markdown_metadata(
                attachment,
                root_metadata_key="attachment",
                # An attachment without content must still leave a trace in the message
                allow_empty=not content,
                minimize_metadata=minimize_metadata,
                minimize_keys=minimize_keys,
                clean_metadata=clean_metadata,
                clean_metadata_kwargs=clean_metadata_kwargs,
                dump_metadata_kwargs=dump_metadata_kwargs,
                **to_markdown_kwargs,
            )
            if content:
                text += "\n\n" + content
        new_d["content"] = text
        return new_d


class AggBlockAssetFunc(AssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_blocks`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "agg_block"

    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def prepare(
        cls,
        aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
        parent_links_only: tp.Optional[bool] = None,
        minimize_metadata: tp.Optional[bool] = None,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: tp.Optional[bool] = None,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        link_map: tp.Optional[tp.Dict[str, dict]] = None,
        asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
        **kwargs,
    ) -> tp.ArgsKwargs:
        """Resolve arguments against the settings of `asset_cls` (`MessagesAsset` if None).

        `to_markdown_kwargs` and `link_map` are forwarded as provided.

        Returns an empty tuple of positional arguments and the keyword arguments for `call`."""
        if asset_cls is None:
            from vectorbtpro.utils.knowledge.custom_assets import MessagesAsset

            asset_cls = MessagesAsset
        aggregate_fields = asset_cls.resolve_setting(aggregate_fields, "aggregate_fields")
        parent_links_only = asset_cls.resolve_setting(parent_links_only, "parent_links_only")
        metadata_settings = _resolve_metadata_settings(
            asset_cls,
            minimize_metadata,
            minimize_keys,
            clean_metadata,
            clean_metadata_kwargs,
            dump_metadata_kwargs,
        )
        return (), {
            "aggregate_fields": aggregate_fields,
            "parent_links_only": parent_links_only,
            **metadata_settings,
            # Previously accepted but dropped, so user-provided values never reached `call`
            "to_markdown_kwargs": to_markdown_kwargs,
            "link_map": link_map,
            **kwargs,
        }

    @staticmethod
    def _normalize_aggregate_fields(aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]]) -> set:
        """Convert `aggregate_fields` into a set of field names.

        True selects "mentions", "attachments", and "reactions"; False and None select nothing."""
        if aggregate_fields is None:
            return set()
        if isinstance(aggregate_fields, bool):
            if aggregate_fields:
                return {"mentions", "attachments", "reactions"}
            return set()
        if isinstance(aggregate_fields, str):
            return {aggregate_fields}
        if isinstance(aggregate_fields, set):
            return aggregate_fields
        return set(aggregate_fields)

    @staticmethod
    def _merge_unique(values: tp.Iterable) -> list:
        """Flatten an iterable of lists into one list of unique elements, keeping first-seen order."""
        merged = []
        for sub_values in values:
            for value in sub_values or ():
                # Elements may be unhashable (e.g., dicts), hence no set here
                if value not in merged:
                    merged.append(value)
        return merged

    @classmethod
    def _consume_shared_key(
        cls,
        k: str,
        v: tp.Any,
        new_d: dict,
        aggregate_fields: set,
        parent_links_only: bool,
    ) -> bool:
        """Handle the keys that are treated the same way by all aggregation levels.

        Returns True if `k` has been fully handled and must not become per-message metadata."""
        if k == "content":
            # Placeholder that fixes the key position; replaced by the joined content later
            new_d[k] = []
            return True
        if k in aggregate_fields and v and isinstance(v[0], list):
            # Iterate over the message values `v` (the previous code iterated over the empty result)
            new_d[k] = cls._merge_unique(v)
            return True
        if k == "reactions" and k in aggregate_fields:
            new_d[k] = sum(v)
            return True
        if parent_links_only and k in ("link", "block", "thread", "reference", "replies"):
            return True
        return False

    @classmethod
    def _join_messages(
        cls,
        d: dict,
        metadata_keys: tp.List[str],
        metadata_kwargs: dict,
        to_markdown_kwargs: dict,
    ) -> str:
        """Join the messages in `d` into a single content string.

        For each message, its Markdown metadata (under the root key "message", built from
        `metadata_keys`) is followed by its stripped content; parts are separated by blank lines.
        Nothing is joined if `metadata_keys` is empty."""
        parts = []
        if len(metadata_keys) > 0:
            for i in range(len(d[metadata_keys[0]])):
                content = d["content"][i].strip()
                metadata = {k: d[k][i] for k in metadata_keys}
                if len(parts) > 0:
                    parts.append("\n\n")
                parts.append(
                    ToMarkdownAssetFunc.get_markdown_metadata(
                        metadata,
                        root_metadata_key="message",
                        # A message without content must still leave a trace
                        allow_empty=not content,
                        **metadata_kwargs,
                        **to_markdown_kwargs,
                    )
                )
                if content:
                    parts.append("\n\n" + content)
        return "".join(parts)

    @classmethod
    def call(
        cls,
        d: tp.Any,
        aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
        parent_links_only: bool = True,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        link_map: tp.Optional[tp.Dict[str, dict]] = None,
    ) -> tp.Any:
        """Aggregate the messages of a block into a single data item.

        `d` maps each field to a list with one value per message. The block link becomes the new
        link, the first value is taken for "timestamp", "thread", "channel", and "author",
        references and replies are translated into block links using `link_map` (unresolvable ones
        become "?"), fields in `aggregate_fields` are merged, and the remaining fields become
        per-message metadata inside the joined content.

        Raises `TypeError` if `d` is not a dict."""
        if not isinstance(d, dict):
            raise TypeError("Data item must be a dict")
        aggregate_fields = cls._normalize_aggregate_fields(aggregate_fields)
        if clean_metadata_kwargs is None:
            clean_metadata_kwargs = {}
        if dump_metadata_kwargs is None:
            dump_metadata_kwargs = {}
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}

        new_d = {}
        metadata_keys = []
        for k, v in d.items():
            if k == "link":
                new_d[k] = d["block"][0]
            if k == "block":
                continue
            if k == "timestamp":
                new_d[k] = v[0]
            if k in {"thread", "channel", "author"}:
                new_d[k] = v[0]
                continue
            if k == "reference" and link_map is not None:
                found_missing = False
                new_v = []
                for _v in v:
                    if _v:
                        if _v in link_map:
                            _v = link_map[_v]["block"]
                        else:
                            found_missing = True
                            break
                    if _v not in new_v:
                        new_v.append(_v)
                if found_missing or len(new_v) > 1:
                    # Unknown or ambiguous reference
                    new_d[k] = "?"
                else:
                    new_d[k] = new_v[0]
            if k == "replies" and link_map is not None:
                new_v = []
                for _v in v:
                    for __v in _v:
                        if __v and __v in link_map:
                            __v = link_map[__v]["block"]
                            if __v not in new_v:
                                new_v.append(__v)
                        else:
                            new_v.append("?")
                new_d[k] = new_v
            if cls._consume_shared_key(k, v, new_d, aggregate_fields, parent_links_only):
                continue
            metadata_keys.append(k)

        new_d["content"] = cls._join_messages(
            d,
            metadata_keys,
            dict(
                minimize_metadata=minimize_metadata,
                minimize_keys=minimize_keys,
                clean_metadata=clean_metadata,
                clean_metadata_kwargs=clean_metadata_kwargs,
                dump_metadata_kwargs=dump_metadata_kwargs,
            ),
            to_markdown_kwargs,
        )
        return new_d


class AggThreadAssetFunc(AggBlockAssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_threads`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "agg_thread"

    @classmethod
    def call(
        cls,
        d: tp.Any,
        aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
        parent_links_only: bool = True,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        link_map: tp.Optional[tp.Dict[str, dict]] = None,
    ) -> tp.Any:
        """Aggregate the messages of a thread into a single data item.

        `d` maps each field to a list with one value per message. The thread link becomes the new
        link, the first value is taken for "timestamp" and "channel", fields in `aggregate_fields`
        are merged, and the remaining fields become per-message metadata inside the joined content.
        `link_map` is accepted for signature compatibility and not used here.

        Raises `TypeError` if `d` is not a dict."""
        if not isinstance(d, dict):
            raise TypeError("Data item must be a dict")
        aggregate_fields = cls._normalize_aggregate_fields(aggregate_fields)
        if clean_metadata_kwargs is None:
            clean_metadata_kwargs = {}
        if dump_metadata_kwargs is None:
            dump_metadata_kwargs = {}
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}

        new_d = {}
        metadata_keys = []
        for k, v in d.items():
            if k == "link":
                new_d[k] = d["thread"][0]
            if k == "thread":
                continue
            if k == "timestamp":
                new_d[k] = v[0]
            if k == "channel":
                new_d[k] = v[0]
                continue
            if cls._consume_shared_key(k, v, new_d, aggregate_fields, parent_links_only):
                continue
            metadata_keys.append(k)

        new_d["content"] = cls._join_messages(
            d,
            metadata_keys,
            dict(
                minimize_metadata=minimize_metadata,
                minimize_keys=minimize_keys,
                clean_metadata=clean_metadata,
                clean_metadata_kwargs=clean_metadata_kwargs,
                dump_metadata_kwargs=dump_metadata_kwargs,
            ),
            to_markdown_kwargs,
        )
        return new_d


class AggChannelAssetFunc(AggThreadAssetFunc):
    """Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_channels`."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "agg_channel"

    @classmethod
    def get_channel_link(cls, link: str) -> str:
        """Get channel link from a message link.

        Supports links starting with "$discord/" (first path part is the channel) and
        "https://discord.com/channels/" (first two path parts are the guild and the channel).
        Raises `ValueError` for any other prefix."""
        if link.startswith("$discord/"):
            link = link[len("$discord/") :]
            link_parts = link.split("/")
            channel_id = link_parts[0]
            return "$discord/" + channel_id
        if link.startswith("https://discord.com/channels/"):
            link = link[len("https://discord.com/channels/") :]
            link_parts = link.split("/")
            guild_id = link_parts[0]
            channel_id = link_parts[1]
            return f"https://discord.com/channels/{guild_id}/{channel_id}"
        raise ValueError(f"Invalid link: '{link}'")

    @classmethod
    def call(
        cls,
        d: tp.Any,
        aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
        parent_links_only: bool = True,
        minimize_metadata: bool = False,
        minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
        clean_metadata: bool = True,
        clean_metadata_kwargs: tp.KwargsLike = None,
        dump_metadata_kwargs: tp.KwargsLike = None,
        to_markdown_kwargs: tp.KwargsLike = None,
        link_map: tp.Optional[tp.Dict[str, dict]] = None,
    ) -> tp.Any:
        """Aggregate the messages of a channel into a single data item.

        `d` maps each field to a list with one value per message. The channel link derived from the
        first message link (see `AggChannelAssetFunc.get_channel_link`) becomes the new link, the
        first value is taken for "timestamp" and "channel", fields in `aggregate_fields` are merged,
        and the remaining fields become per-message metadata inside the joined content.
        `link_map` is accepted for signature compatibility and not used here.

        Raises `TypeError` if `d` is not a dict."""
        if not isinstance(d, dict):
            raise TypeError("Data item must be a dict")
        aggregate_fields = cls._normalize_aggregate_fields(aggregate_fields)
        if clean_metadata_kwargs is None:
            clean_metadata_kwargs = {}
        if dump_metadata_kwargs is None:
            dump_metadata_kwargs = {}
        if to_markdown_kwargs is None:
            to_markdown_kwargs = {}

        new_d = {}
        metadata_keys = []
        for k, v in d.items():
            if k == "link":
                new_d[k] = cls.get_channel_link(v[0])
            if k == "timestamp":
                new_d[k] = v[0]
            if k == "channel":
                new_d[k] = v[0]
                continue
            if cls._consume_shared_key(k, v, new_d, aggregate_fields, parent_links_only):
                continue
            metadata_keys.append(k)

        new_d["content"] = cls._join_messages(
            d,
            metadata_keys,
            dict(
                minimize_metadata=minimize_metadata,
                minimize_keys=minimize_keys,
                clean_metadata=clean_metadata,
                clean_metadata_kwargs=clean_metadata_kwargs,
                dump_metadata_kwargs=dump_metadata_kwargs,
            ),
            to_markdown_kwargs,
        )
        return new_d


# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== """Custom asset classes.""" import inspect import io import os import pkgutil import re import base64 from collections import defaultdict, deque from pathlib import Path from types import ModuleType from vectorbtpro import _typing as tp from vectorbtpro.utils import checks from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts, reorder_list, HybridConfig, SpecSettingsPath from vectorbtpro.utils.decorators import hybrid_method from vectorbtpro.utils.knowledge.base_assets import AssetCacheManager, KnowledgeAsset from vectorbtpro.utils.knowledge.formatting import FormatHTML from vectorbtpro.utils.module_ import prepare_refname, get_caller_qualname from vectorbtpro.utils.parsing import get_func_arg_names from vectorbtpro.utils.path_ import check_mkdir, remove_dir, get_common_prefix, dir_tree_from_paths from vectorbtpro.utils.pbar import ProgressBar from vectorbtpro.utils.pickling import suggest_compression from vectorbtpro.utils.search_ import find, replace from vectorbtpro.utils.template import CustomTemplate from vectorbtpro.utils.warnings_ import warn __all__ = [ "VBTAsset", "PagesAsset", "MessagesAsset", "find_api", "find_docs", "find_messages", "find_examples", "find_assets", "chat_about", "search", "chat", ] __pdoc__ = {} class_abbr_config = HybridConfig( dict( Accessor={"acc"}, Array={"arr"}, ArrayWrapper={"wrapper"}, Benchmark={"bm"}, Cacheable={"ca"}, Chunkable={"ch"}, Drawdowns={"dd"}, Jitable={"jit"}, Figure={"fig"}, MappedArray={"ma"}, NumPy={"np"}, Numba={"nb"}, Optimizer={"opt"}, Pandas={"pd"}, Portfolio={"pf"}, ProgressBar={"pbar"}, Registry={"reg"}, Returns_={"ret"}, Returns={"rets"}, QuantStats={"qs"}, Signals_={"sig"}, ) ) """_""" __pdoc__[ "class_abbr_config" ] = f"""Config for class name (part) abbreviations. 
```python {class_abbr_config.prettify()} ``` """ class NoItemFoundError(Exception): """Exception raised when no data item was found.""" class MultipleItemsFoundError(Exception): """Exception raised when multiple data items were found.""" VBTAssetT = tp.TypeVar("VBTAssetT", bound="VBTAsset") class VBTAsset(KnowledgeAsset): """Class for working with VBT content. For defaults, see `assets.vbt` in `vectorbtpro._settings.knowledge`.""" _settings_path: tp.SettingsPath = "knowledge.assets.vbt" def __init__(self, *args, release_name: tp.Optional[str] = None, **kwargs) -> None: KnowledgeAsset.__init__(self, *args, release_name=release_name, **kwargs) self._release_name = release_name @property def release_name(self) -> tp.Optional[str]: """Release name.""" return self._release_name @classmethod def pull( cls: tp.Type[VBTAssetT], release_name: tp.Optional[str] = None, asset_name: tp.Optional[str] = None, repo_owner: tp.Optional[str] = None, repo_name: tp.Optional[str] = None, token: tp.Optional[str] = None, token_required: tp.Optional[bool] = None, use_pygithub: tp.Optional[bool] = None, chunk_size: tp.Optional[int] = None, cache: tp.Optional[bool] = None, cache_dir: tp.Optional[tp.PathLike] = None, cache_mkdir_kwargs: tp.KwargsLike = None, clear_cache: tp.Optional[bool] = None, show_progress: tp.Optional[bool] = None, pbar_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> VBTAssetT: """Build `VBTAsset` from a JSON asset of a release. Examples of a release name include None or 'current' for the current release, 'latest' for the latest release, and any other tag name such as 'v2024.12.15'. An example of an asset file name is 'messages.json.zip'. You can find all asset file names at https://github.com/polakowo/vectorbt.pro/releases/latest Token must be a valid GitHub token. It doesn't have to be provided if the asset has already been downloaded. 
If `use_pygithub` is True, uses https://github.com/PyGithub/PyGithub (otherwise requests) Argument `chunk_size` denotes the number of bytes in each chunk when downloading an asset file. If `cache` is True, uses the cache directory (`assets_dir` in settings). Otherwise, builds the asset instance in memory. If `clear_cache` is True, deletes any existing directory before creating a new one.""" import requests release_name = cls.resolve_setting(release_name, "release_name") asset_name = cls.resolve_setting(asset_name, "asset_name") repo_owner = cls.resolve_setting(repo_owner, "repo_owner") repo_name = cls.resolve_setting(repo_name, "repo_name") token = cls.resolve_setting(token, "token") token_required = cls.resolve_setting(token_required, "token_required") use_pygithub = cls.resolve_setting(use_pygithub, "use_pygithub") chunk_size = cls.resolve_setting(chunk_size, "chunk_size") cache = cls.resolve_setting(cache, "cache") assets_dir = cls.resolve_setting(cache_dir, "assets_dir") cache_mkdir_kwargs = cls.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True) clear_cache = cls.resolve_setting(clear_cache, "clear_cache") show_progress = cls.resolve_setting(show_progress, "show_progress") pbar_kwargs = cls.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True) template_context = cls.resolve_setting(template_context, "template_context", merge=True) if release_name is None or release_name.lower() == "current": from vectorbtpro._version import __release__ release_name = __release__ if release_name.lower() == "latest": if token is None: token = os.environ.get("GITHUB_TOKEN", None) if token is None and token_required: raise ValueError("GitHub token is required") if use_pygithub is None: from vectorbtpro.utils.module_ import check_installed use_pygithub = check_installed("github") if use_pygithub: from vectorbtpro.utils.module_ import assert_can_import assert_can_import("github") from github import Github, Auth from github.GithubException import 
UnknownObjectException if token is not None: g = Github(auth=Auth.Token(token)) else: g = Github() try: repo = g.get_repo(f"{repo_owner}/{repo_name}") except UnknownObjectException: raise Exception(f"Repository '{repo_owner}/{repo_name}' not found or access denied") try: release = repo.get_latest_release() except UnknownObjectException: raise Exception("Latest release not found") release_name = release.title else: headers = {"Accept": "application/vnd.github+json"} if token is not None: headers["Authorization"] = f"token {token}" release_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" response = requests.get(release_url, headers=headers) response.raise_for_status() release_info = response.json() release_name = release_info.get("name") template_context = flat_merge_dicts(dict(release_name=release_name), template_context) if isinstance(assets_dir, CustomTemplate): cache_dir = cls.get_setting("cache_dir") if isinstance(cache_dir, CustomTemplate): cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir") template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context) release_dir = cls.get_setting("release_dir") if isinstance(release_dir, CustomTemplate): release_dir = release_dir.substitute(template_context, eval_id="release_dir") template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context) assets_dir = assets_dir.substitute(template_context, eval_id="assets_dir") if cache: if assets_dir.exists(): if clear_cache: remove_dir(assets_dir, missing_ok=True, with_contents=True) else: cache_file = None for file in assets_dir.iterdir(): if file.is_file() and file.name == asset_name: cache_file = file break if cache_file is not None: return cls.from_json_file(cache_file, release_name=release_name, **kwargs) if token is None: token = os.environ.get("GITHUB_TOKEN", None) if token is None and token_required: raise ValueError("GitHub token is required") if use_pygithub is None: from 
vectorbtpro.utils.module_ import check_installed use_pygithub = check_installed("github") if use_pygithub: from vectorbtpro.utils.module_ import assert_can_import assert_can_import("github") from github import Github, Auth from github.GithubException import UnknownObjectException if token is not None: g = Github(auth=Auth.Token(token)) else: g = Github() try: repo = g.get_repo(f"{repo_owner}/{repo_name}") except UnknownObjectException: raise Exception(f"Repository '{repo_owner}/{repo_name}' not found or access denied") releases = repo.get_releases() found_release = None for release in releases: if release.title == release_name: found_release = release if found_release is None: raise Exception(f"Release '{release_name}' not found") release = found_release assets = release.get_assets() if asset_name is not None: asset = next((a for a in assets if a.name == asset_name), None) if asset is None: raise Exception(f"Asset '{asset_name}' not found in release {release_name}") else: assets_list = list(assets) if len(assets_list) == 1: asset = assets_list[0] else: raise Exception("Please specify asset_name") asset_url = asset.url else: headers = {"Accept": "application/vnd.github+json"} if token is not None: headers["Authorization"] = f"token {token}" releases_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases" response = requests.get(releases_url, headers=headers) response.raise_for_status() releases = response.json() release_info = None for release in releases: if release.get("name") == release_name: release_info = release if release_info is None: raise ValueError(f"Release '{release_name}' not found") assets = release_info.get("assets", []) if asset_name is not None: asset = next((a for a in assets if a["name"] == asset_name), None) if asset is None: raise Exception(f"Asset '{asset_name}' not found in release {release_name}") else: if len(assets) == 1: asset = assets[0] else: raise Exception("Please specify asset_name") asset_url = asset["url"] 
asset_headers = {"Accept": "application/octet-stream"} if token is not None: asset_headers["Authorization"] = f"token {token}" asset_response = requests.get(asset_url, headers=asset_headers, stream=True) asset_response.raise_for_status() file_size = int(asset_response.headers.get("Content-Length", 0)) if file_size == 0: file_size = asset.get("size", 0) if show_progress is None: show_progress = True pbar_kwargs = flat_merge_dicts( dict( bar_id=get_caller_qualname(), unit="iB", unit_scale=True, prefix=f"Downloading {asset_name}", ), pbar_kwargs, ) if cache: check_mkdir(assets_dir, **cache_mkdir_kwargs) cache_file = assets_dir / asset_name with open(cache_file, "wb") as f: with ProgressBar(total=file_size, show_progress=show_progress, **pbar_kwargs) as pbar: for chunk in asset_response.iter_content(chunk_size=chunk_size): if chunk: f.write(chunk) pbar.update(len(chunk)) return cls.from_json_file(cache_file, release_name=release_name, **kwargs) else: with io.BytesIO() as bytes_io: with ProgressBar(total=file_size, show_progress=show_progress, **pbar_kwargs) as pbar: for chunk in asset_response.iter_content(chunk_size=chunk_size): if chunk: bytes_io.write(chunk) pbar.update(len(chunk)) bytes_ = bytes_io.getvalue() compression = suggest_compression(asset_name) if compression is not None and "compression" not in kwargs: kwargs["compression"] = compression return cls.from_json_bytes(bytes_, release_name=release_name, **kwargs) def find_link( self: VBTAssetT, link: tp.MaybeList[str], mode: str = "end", per_path: bool = False, single_item: bool = True, consolidate: bool = True, allow_empty: bool = False, **kwargs, ) -> tp.MaybeVBTAsset: """Find item(s) corresponding to link(s).""" def _extend_link(link): from urllib.parse import urlparse if not urlparse(link).fragment: if link.endswith("/"): return [link, link[:-1]] return [link, link + "/"] return [link] links = link if mode.lower() in ("exact", "end"): if isinstance(link, str): links = _extend_link(link) elif 
isinstance(link, list): from itertools import chain links = list(chain(*map(_extend_link, link))) else: raise TypeError("Link must be either string or list") found = self.find(links, path="link", mode=mode, per_path=per_path, single_item=single_item, **kwargs) if isinstance(found, (type(self), list)): if len(found) == 0: if allow_empty: return found raise NoItemFoundError(f"No item matching '{link}'") if single_item and len(found) > 1: if consolidate: top_parents = self.get_top_parent_links(list(found)) if len(top_parents) == 1: for i, d in enumerate(found): if d["link"] == top_parents[0]: if isinstance(found, type(self)): return found.replace(data=[d], single_item=True) return d links_block = "\n".join([d["link"] for d in found]) raise MultipleItemsFoundError(f"Multiple items matching '{link}':\n\n{links_block}") return found @classmethod def minimize_link(cls, link: str, rules: tp.Optional[tp.Dict[str, str]] = None) -> str: """Minimize a single link.""" rules = cls.resolve_setting(rules, "minimize_link_rules", merge=True) for k, v in rules.items(): link = replace(k, v, link, mode="regex") return link def minimize_links(self, rules: tp.Optional[tp.Dict[str, str]] = None) -> tp.MaybeVBTAsset: """Minimize links.""" rules = self.resolve_setting(rules, "minimize_link_rules", merge=True) return self.find_replace(rules, mode="regex") def minimize( self, keys: tp.Optional[tp.List[str]] = None, links: tp.Optional[bool] = None, ) -> tp.MaybeVBTAsset: """Minimize by keeping the most useful information. 
If `minimize_links` is True, replaces redundant URL prefixes by templates that can be easily substituted later.""" keys = self.resolve_setting(keys, "minimize_keys") links = self.resolve_setting(links, "minimize_links") new_instance = self.find_remove_empty() if links: return new_instance.minimize_links() if keys: new_instance = new_instance.remove(keys, skip_missing=True) return new_instance def select_previous(self: VBTAssetT, link: str, **kwargs) -> VBTAssetT: """Select the previous data item.""" d = self.find_link(link, wrap=False, **kwargs) d_index = self.index(d) new_data = [] if d_index > 0: new_data.append(self.data[d_index - 1]) return self.replace(data=new_data, single_item=True) def select_next(self: VBTAssetT, link: str, **kwargs) -> VBTAssetT: """Select the next data item.""" d = self.find_link(link, wrap=False, **kwargs) d_index = self.index(d) new_data = [] if d_index < len(self.data) - 1: new_data.append(self.data[d_index + 1]) return self.replace(data=new_data, single_item=True) def to_markdown( self, root_metadata_key: tp.Optional[tp.Key] = None, clean_metadata: tp.Optional[bool] = None, clean_metadata_kwargs: tp.KwargsLike = None, dump_metadata_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeVBTAsset: """Convert to Markdown. Uses `VBTAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.ToMarkdownAssetFunc`. Use `root_metadata_key` to provide the root key for the metadata markdown. If `clean_metadata` is True, removes empty fields from the metadata. Arguments in `clean_metadata_kwargs` are passed to `vectorbtpro.utils.knowledge.base_asset_funcs.FindRemoveAssetFunc`, while `dump_metadata_kwargs` are passed to `vectorbtpro.utils.knowledge.base_asset_funcs.DumpAssetFunc`. 
Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_markdown`.""" return self.apply( "to_markdown", root_metadata_key=root_metadata_key, clean_metadata=clean_metadata, clean_metadata_kwargs=clean_metadata_kwargs, dump_metadata_kwargs=dump_metadata_kwargs, **kwargs, ) @classmethod def links_to_paths( cls, urls: tp.Iterable[str], extension: tp.Optional[str] = None, allow_fragments: bool = True, ) -> tp.List[Path]: """Convert links to corresponding paths.""" from urllib.parse import urlparse url_paths = [] for url in urls: parsed = urlparse(url, allow_fragments=allow_fragments) path_parts = [parsed.netloc] url_path = parsed.path.strip("/") if url_path: parts = url_path.split("/") if parsed.fragment: path_parts.extend(parts) if extension is not None: file_name = parsed.fragment + "." + extension else: file_name = parsed.fragment path_parts.append(file_name) else: if len(parts) > 1: path_parts.extend(parts[:-1]) last_part = parts[-1] if extension is not None: file_name = last_part + "." + extension else: file_name = last_part path_parts.append(file_name) else: if parsed.fragment: if extension is not None: file_name = parsed.fragment + "." + extension else: file_name = parsed.fragment path_parts.append(file_name) else: if extension is not None: path_parts.append("index." + extension) else: path_parts.append("index") url_paths.append(Path(os.path.join(*path_parts))) return url_paths def save_to_markdown( self, cache: tp.Optional[bool] = None, cache_dir: tp.Optional[tp.PathLike] = None, cache_mkdir_kwargs: tp.KwargsLike = None, clear_cache: tp.Optional[bool] = None, show_progress: tp.Optional[bool] = None, pbar_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> Path: """Save to Markdown files. If `cache` is True, uses the cache directory (`markdown_dir` in settings). Otherwise, creates a temporary directory. If `clear_cache` is True, deletes any existing directory before creating a new one. 
Returns the path of the directory where Markdown files are stored. Keyword arguments are passed to `vectorbtpro.utils.knowledge.custom_asset_funcs.ToMarkdownAssetFunc`. Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_markdown`.""" import tempfile from vectorbtpro.utils.knowledge.custom_asset_funcs import ToMarkdownAssetFunc cache = self.resolve_setting(cache, "cache") markdown_dir = self.resolve_setting(cache_dir, "markdown_dir") cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True) clear_cache = self.resolve_setting(clear_cache, "clear_cache") show_progress = self.resolve_setting(show_progress, "show_progress") pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True) template_context = self.resolve_setting(template_context, "template_context", merge=True) if cache: if self.release_name: template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context) if isinstance(markdown_dir, CustomTemplate): cache_dir = self.get_setting("cache_dir") if isinstance(cache_dir, CustomTemplate): cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir") template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context) release_dir = self.get_setting("release_dir") if isinstance(release_dir, CustomTemplate): release_dir = release_dir.substitute(template_context, eval_id="release_dir") template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context) markdown_dir = markdown_dir.substitute(template_context, eval_id="markdown_dir") if markdown_dir.exists(): if clear_cache: remove_dir(markdown_dir, missing_ok=True, with_contents=True) check_mkdir(markdown_dir, **cache_mkdir_kwargs) else: markdown_dir = Path(tempfile.mkdtemp(prefix=get_caller_qualname() + "_")) link_map = {d["link"]: dict(d) for d in self.data} url_paths = self.links_to_paths(link_map.keys(), extension="md") url_file_map = dict(zip(link_map.keys(), 
[markdown_dir / p for p in url_paths])) _, kwargs = ToMarkdownAssetFunc.prepare(**kwargs) if show_progress is None: show_progress = not self.single_item prefix = get_caller_qualname().split(".")[-1] pbar_kwargs = flat_merge_dicts( dict( bar_id=get_caller_qualname(), prefix=prefix, ), pbar_kwargs, ) with ProgressBar(total=len(self.data), show_progress=show_progress, **pbar_kwargs) as pbar: for d in self.data: if not url_file_map[d["link"]].exists(): markdown_content = ToMarkdownAssetFunc.call(d, **kwargs) check_mkdir(url_file_map[d["link"]].parent, mkdir=True) with open(url_file_map[d["link"]], "w", encoding="utf-8") as f: f.write(markdown_content) pbar.update() return markdown_dir def to_html( self: VBTAssetT, root_metadata_key: tp.Optional[tp.Key] = None, clean_metadata: tp.Optional[bool] = None, clean_metadata_kwargs: tp.KwargsLike = None, dump_metadata_kwargs: tp.KwargsLike = None, to_markdown_kwargs: tp.KwargsLike = None, format_html_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeVBTAsset: """Convert to HTML. Uses `VBTAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.ToHTMLAssetFunc`. Arguments in `format_html_kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.format_html`. Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_html`. 
For other arguments, see `VBTAsset.to_markdown`.""" return self.apply( "to_html", root_metadata_key=root_metadata_key, clean_metadata=clean_metadata, clean_metadata_kwargs=clean_metadata_kwargs, dump_metadata_kwargs=dump_metadata_kwargs, to_markdown_kwargs=to_markdown_kwargs, format_html_kwargs=format_html_kwargs, **kwargs, ) @classmethod def get_top_parent_links(cls, data: tp.List) -> tp.List[str]: """Get links of top parents in data.""" link_map = {d["link"]: dict(d) for d in data} top_parents = [] for d in data: if d.get("parent", None) is None or d["parent"] not in link_map: top_parents.append(d["link"]) return top_parents @property def top_parent_links(self) -> tp.List[str]: """Get links of top parents.""" return self.get_top_parent_links(self.data) @classmethod def replace_urls_in_html(cls, html: str, url_map: dict) -> str: """Replace URLs in attributes based on a provided mapping.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("bs4") from bs4 import BeautifulSoup from urllib.parse import urlparse, urlunparse soup = BeautifulSoup(html, "html.parser") for a_tag in soup.find_all("a", href=True): original_href = a_tag["href"] if original_href in url_map: a_tag["href"] = url_map[original_href] else: try: parsed_href = urlparse(original_href) base_url = urlunparse(parsed_href._replace(fragment="")) if base_url in url_map: new_base_url = url_map[base_url] new_parsed = urlparse(new_base_url) new_parsed = new_parsed._replace(fragment=parsed_href.fragment) new_href = urlunparse(new_parsed) a_tag["href"] = new_href except ValueError: pass return str(soup) def save_to_html( self, cache: tp.Optional[bool] = None, cache_dir: tp.Optional[tp.PathLike] = None, cache_mkdir_kwargs: tp.KwargsLike = None, clear_cache: tp.Optional[bool] = None, show_progress: tp.Optional[bool] = None, pbar_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, return_url_map: bool = False, **kwargs, ) -> tp.Union[Path, tp.Tuple[Path, dict]]: """Save to 
HTML files. Opens the web browser. Also, returns the path of the directory where HTML files are stored, and if `return_url_map` is True also returns the link->file map. In addition, if there are multiple top-level parents, creates an index page. If `cache` is True, uses the cache directory (`html_dir` in settings). Otherwise, creates a temporary directory. If `clear_cache` is True, deletes any existing directory before creating a new one. Keyword arguments are passed to `vectorbtpro.utils.knowledge.custom_asset_funcs.ToHTMLAssetFunc`.""" import tempfile from vectorbtpro.utils.knowledge.custom_asset_funcs import ToHTMLAssetFunc cache = self.resolve_setting(cache, "cache") html_dir = self.resolve_setting(cache_dir, "html_dir") cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True) clear_cache = self.resolve_setting(clear_cache, "clear_cache") show_progress = self.resolve_setting(show_progress, "show_progress") pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True) if cache: if self.release_name: template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context) if isinstance(html_dir, CustomTemplate): cache_dir = self.get_setting("cache_dir") if isinstance(cache_dir, CustomTemplate): cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir") template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context) release_dir = self.get_setting("release_dir") if isinstance(release_dir, CustomTemplate): release_dir = release_dir.substitute(template_context, eval_id="release_dir") template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context) html_dir = html_dir.substitute(template_context, eval_id="html_dir") if html_dir.exists(): if clear_cache: remove_dir(html_dir, missing_ok=True, with_contents=True) check_mkdir(html_dir, **cache_mkdir_kwargs) else: html_dir = Path(tempfile.mkdtemp(prefix=get_caller_qualname() + "_")) link_map = 
{d["link"]: dict(d) for d in self.data} top_parents = self.top_parent_links if len(top_parents) > 1: link_map["/"] = {} url_paths = self.links_to_paths(link_map.keys(), extension="html") url_file_map = dict(zip(link_map.keys(), [html_dir / p for p in url_paths])) url_map = {k: "file://" + str(v.resolve()) for k, v in url_file_map.items()} _, kwargs = ToHTMLAssetFunc.prepare(**kwargs) if len(top_parents) > 1: entry_link = "/" if not url_file_map[entry_link].exists(): html = ToHTMLAssetFunc.call([link_map[link] for link in top_parents], **kwargs) html = self.replace_urls_in_html(html, url_map) check_mkdir(url_file_map[entry_link].parent, mkdir=True) with open(url_file_map[entry_link], "w", encoding="utf-8") as f: f.write(html) if show_progress is None: show_progress = not self.single_item prefix = get_caller_qualname().split(".")[-1] pbar_kwargs = flat_merge_dicts( dict( bar_id=get_caller_qualname(), prefix=prefix, ), pbar_kwargs, ) with ProgressBar(total=len(self.data), show_progress=show_progress, **pbar_kwargs) as pbar: for d in self.data: if not url_file_map[d["link"]].exists(): html = ToHTMLAssetFunc.call(d, **kwargs) html = self.replace_urls_in_html(html, url_map) check_mkdir(url_file_map[d["link"]].parent, mkdir=True) with open(url_file_map[d["link"]], "w", encoding="utf-8") as f: f.write(html) pbar.update() if return_url_map: return html_dir, url_map return html_dir def browse( self, entry_link: tp.Optional[str] = None, find_kwargs: tp.KwargsLike = None, open_browser: tp.Optional[bool] = None, **kwargs, ) -> Path: """Browse one or more HTML pages. Opens the web browser. Also, returns the path of the directory where HTML files are stored. Use `entry_link` to specify the link of the page that should be displayed first. If `entry_link` is None and there are multiple top-level parents, displays them as an index. If it's not None, it will be matched using `VBTAsset.find_link` and `find_kwargs`. 
Keyword arguments are passed to `PagesAsset.save_to_html`.""" open_browser = self.resolve_setting(open_browser, "open_browser") if entry_link is None: if len(self.data) == 1: entry_link = self.data[0]["link"] else: top_parents = self.top_parent_links if len(top_parents) == 1: entry_link = top_parents[0] else: entry_link = "/" else: if find_kwargs is None: find_kwargs = {} d = self.find_link(entry_link, wrap=False, **find_kwargs) entry_link = d["link"] html_dir, url_map = self.save_to_html(return_url_map=True, **kwargs) if open_browser: import webbrowser webbrowser.open(url_map[entry_link]) return html_dir def display( self, link: tp.Optional[str] = None, find_kwargs: tp.KwargsLike = None, open_browser: tp.Optional[bool] = None, html_template: tp.Optional[str] = None, style_extras: tp.Optional[tp.MaybeList[str]] = None, head_extras: tp.Optional[tp.MaybeList[str]] = None, body_extras: tp.Optional[tp.MaybeList[str]] = None, invert_colors: tp.Optional[bool] = None, title: str = "", template_context: tp.KwargsLike = None, **kwargs, ) -> Path: """Display as an HTML page. If there are multiple HTML pages, shows them as iframes inside a parent HTML page with pagination. For this, uses `vectorbtpro.utils.knowledge.formatting.FormatHTML`. Opens the web browser. Also, returns the path of the temporary HTML file. !!! note The file __won't__ be deleted automatically. 
""" import tempfile open_browser = self.resolve_setting(open_browser, "open_browser", sub_path="display") if link is not None: if find_kwargs is None: find_kwargs = {} instance = self.find_link(link, **find_kwargs) else: instance = self html = instance.to_html(wrap=False, single_item=True, **kwargs) if len(instance) > 1: from vectorbtpro.utils.config import ExtSettingsPath encoded_pages = map(lambda x: base64.b64encode(x.encode("utf-8")).decode("ascii"), html) pages = "[\n" + ",\n".join(f' "{page}"' for page in encoded_pages) + "\n]" ext_settings_paths = [] for cls_ in type(self).__mro__[::-1]: if issubclass(cls_, VBTAsset): if not isinstance(cls_._settings_path, str): raise TypeError("_settings_path for VBTAsset and its subclasses must be a string") ext_settings_paths.append((FormatHTML, cls_._settings_path + ".display")) with ExtSettingsPath(ext_settings_paths): html = FormatHTML( html_template=html_template, style_extras=style_extras, head_extras=head_extras, body_extras=body_extras, invert_colors=invert_colors, auto_scroll=False, ).format_html(title=title, pages=pages) with tempfile.NamedTemporaryFile( "w", encoding="utf-8", prefix=get_caller_qualname() + "_", suffix=".html", delete=False, ) as f: f.write(html) file_path = Path(f.name) if open_browser: import webbrowser webbrowser.open("file://" + str(file_path.resolve())) return file_path @classmethod def prepare_mention_target( cls, target: str, as_code: bool = False, as_regex: bool = True, allow_prefix: bool = False, allow_suffix: bool = False, ) -> str: """Prepare a mention target.""" if as_regex: escaped_target = re.escape(target) new_target = "" if not allow_prefix and re.match(r"\w", target[0]): new_target += r"(? 
tp.List[str]: """Split a class name constituent parts.""" return re.findall(r"[A-Z]+(?=[A-Z][a-z]|$)|[A-Z][a-z]+", name) @classmethod def get_class_abbrs(cls, name: str) -> tp.List[str]: """Convert a class name to snake case and its abbreviated versions.""" from itertools import product parts = cls.split_class_name(name) replacement_lists = [] for i, part in enumerate(parts): replacements = [part.lower()] if i == 0 and f"{part}_" in class_abbr_config: replacements.extend(class_abbr_config[f"{part}_"]) if part in class_abbr_config: replacements.extend(class_abbr_config[part]) replacement_lists.append(replacements) all_combinations = list(product(*replacement_lists)) snake_case_names = ["_".join(combo) for combo in all_combinations] return snake_case_names @classmethod def generate_refname_targets( cls, refname: str, resolve: bool = True, incl_shortcuts: tp.Optional[bool] = None, incl_shortcut_access: tp.Optional[bool] = None, incl_shortcut_call: tp.Optional[bool] = None, incl_instances: tp.Optional[bool] = None, as_code: tp.Optional[bool] = None, as_regex: tp.Optional[bool] = None, allow_prefix: tp.Optional[bool] = None, allow_suffix: tp.Optional[bool] = None, ) -> tp.List[str]: """Generate reference name targets. If `incl_shortcuts` is True, includes shortcuts found in `import vectorbtpro as vbt`. In addition, if `incl_shortcut_access` is True and the object is a class or module, includes a version with attribute access, and if `incl_shortcut_call` is True and the object is callable, includes a version that is being called. If `incl_instances` is True, includes typical short names of classes, which include the snake-cased class name and mapped name parts found in `class_abbr_config`. 
Prepares each mention target with `VBTAsset.prepare_mention_target`.""" from vectorbtpro.utils.module_ import annotate_refname_parts import vectorbtpro as vbt incl_shortcuts = cls.resolve_setting(incl_shortcuts, "incl_shortcuts") incl_shortcut_access = cls.resolve_setting(incl_shortcut_access, "incl_shortcut_access") incl_shortcut_call = cls.resolve_setting(incl_shortcut_call, "incl_shortcut_call") incl_instances = cls.resolve_setting(incl_instances, "incl_instances") as_code = cls.resolve_setting(as_code, "as_code") as_regex = cls.resolve_setting(as_regex, "as_regex") allow_prefix = cls.resolve_setting(allow_prefix, "allow_prefix") allow_suffix = cls.resolve_setting(allow_suffix, "allow_suffix") def _prepare_target( target, _as_code=as_code, _as_regex=as_regex, _allow_prefix=allow_prefix, _allow_suffix=allow_suffix, ): return cls.prepare_mention_target( target, as_code=_as_code, as_regex=_as_regex, allow_prefix=_allow_prefix, allow_suffix=_allow_suffix, ) targets = set() new_target = _prepare_target(refname) targets.add(new_target) refname_parts = refname.split(".") if resolve: annotated_parts = annotate_refname_parts(refname) if len(annotated_parts) >= 2 and isinstance(annotated_parts[-2]["obj"], type): cls_refname = ".".join(refname_parts[:-1]) cls_aliases = {annotated_parts[-2]["name"]} attr_aliases = set() for k, v in vbt.__dict__.items(): v_refname = prepare_refname(v, raise_error=False) if v_refname is not None: if v_refname == cls_refname: cls_aliases.add(k) elif v_refname == refname: attr_aliases.add(k) if incl_shortcuts: new_target = _prepare_target("vbt." + k) targets.add(new_target) if incl_shortcuts: for cls_alias in cls_aliases: new_target = _prepare_target(cls_alias + "." 
def generate_mention_targets(
    self,
    obj: tp.MaybeList,
    *,
    attr: tp.Optional[str] = None,
    module: tp.Union[None, str, ModuleType] = None,
    resolve: bool = True,
    incl_base_attr: tp.Optional[bool] = None,
    incl_shortcuts: tp.Optional[bool] = None,
    incl_shortcut_access: tp.Optional[bool] = None,
    incl_shortcut_call: tp.Optional[bool] = None,
    incl_instances: tp.Optional[bool] = None,
    as_code: tp.Optional[bool] = None,
    as_regex: tp.Optional[bool] = None,
    allow_prefix: tp.Optional[bool] = None,
    allow_suffix: tp.Optional[bool] = None,
) -> tp.List[str]:
    """Generate mention targets for one or more objects.

    Each object (a non-list `obj` is treated as a single object) is converted into
    a reference name with `vectorbtpro.utils.module_.prepare_refname`.

    If `attr` is provided, it must be a string. Two reference names are then built:
    the object's reference name with `attr` appended, and the reference name that
    `prepare_refname` returns for the (object, attribute) pair. If both are equal,
    only one is used. If they differ (presumably because the attribute is defined
    by a base class) and `incl_base_attr` is True, targets are generated for both.

    Targets are generated with `VBTAsset.generate_refname_targets`, to which the
    remaining arguments are forwarded. Duplicates are removed while keeping the
    order of first occurrence."""
    from vectorbtpro.utils.module_ import prepare_refname

    incl_base_attr = self.resolve_setting(incl_base_attr, "incl_base_attr")

    # Arguments shared by every call to generate_refname_targets
    refname_kwargs = dict(
        resolve=resolve,
        incl_shortcuts=incl_shortcuts,
        incl_shortcut_access=incl_shortcut_access,
        incl_shortcut_call=incl_shortcut_call,
        incl_instances=incl_instances,
        as_code=as_code,
        as_regex=as_regex,
        allow_prefix=allow_prefix,
        allow_suffix=allow_suffix,
    )

    collected = []
    for item in obj if isinstance(obj, list) else [obj]:
        item_refname = prepare_refname(item, module=module, resolve=resolve)
        inherited_refname = None
        if attr is not None:
            checks.assert_instance_of(attr, str, arg_name="attr")
            attr_item = (*item, attr) if isinstance(item, tuple) else (item, attr)
            inherited_refname = prepare_refname(attr_item, module=module, resolve=resolve)
            item_refname += "." + attr
            if inherited_refname == item_refname:
                # Both names coincide, so there is no separate name to add
                inherited_refname = None
        collected.extend(self.generate_refname_targets(item_refname, **refname_kwargs))
        if incl_base_attr and inherited_refname is not None:
            collected.extend(self.generate_refname_targets(inherited_refname, **refname_kwargs))

    # Order-preserving removal of duplicates
    return list(dict.fromkeys(collected))
@classmethod
def resolve_spec_settings_path(cls) -> dict:
    """Resolve specialized settings paths.

    Walks the MRO of the class from the most generic to the most specific class and,
    for each class that is a subclass of `VBTAsset`, appends its `_settings_path`
    under the key "knowledge" and its `_settings_path` followed by ".chat" under
    the key "knowledge.chat".

    Raises a `TypeError` if any such `_settings_path` is not a string."""
    paths = {}
    for klass in reversed(cls.__mro__):
        if not issubclass(klass, VBTAsset):
            continue
        settings_path = klass._settings_path
        if not isinstance(settings_path, str):
            raise TypeError("_settings_path for VBTAsset and its subclasses must be a string")
        paths.setdefault("knowledge", []).append(settings_path)
        paths.setdefault("knowledge.chat", []).append(settings_path + ".chat")
    return paths
def descend_links(self: PagesAssetT, links: tp.List[str], **kwargs) -> PagesAssetT:
    """Descend links by removing redundant ones.

    For each link, selects the link together with its descendant headings using
    `select_descendant_headings` with `incl_link=True` and collects them keyed by
    link. Any link that appears as a descendant of another processed link is
    considered redundant: it is not processed on its own, and if it was passed
    in `links`, it is dropped from the result.

    Only headings are descended. Keyword arguments are passed to `replace`."""
    covered = set()
    kept = {}
    for start in links:
        if start in covered:
            # Already reached as a descendant of an earlier link
            continue
        for item in self.select_descendant_headings(start, incl_link=True):
            item_link = item["link"]
            if item_link != start:
                covered.add(item_link)
            kept[item_link] = item
    for start in links:
        if start in covered:
            kept.pop(start, None)
    return self.replace(data=list(kept.values()), **kwargs)
def find_refname(
    self,
    refname: tp.MaybeList[str],
    **kwargs,
) -> tp.MaybePagesAsset:
    """Find the page(s) corresponding to reference name(s).

    Each reference name is escaped and turned into the pattern `#(<refname>)$`;
    a list of reference names yields a list of patterns. The pattern(s) are passed
    to `PagesAsset.find_page` with `mode="regex"` along with the keyword arguments."""

    def _to_pattern(name):
        return f"#({re.escape(name)})$"

    link = [_to_pattern(name) for name in refname] if isinstance(refname, list) else _to_pattern(refname)
    return self.find_page(link, mode="regex", **kwargs)
@classmethod
def is_link_module(cls, link: str) -> bool:
    """Return whether a link is a module.

    A link without "/api/" is never a module. A link with "/api/" but without
    a fragment is always a module. Otherwise, the text between the first and the
    second "#" is taken, its dots are replaced by slashes, and the link is a module
    if the result is contained in the link."""
    if "/api/" not in link:
        return False
    if "#" not in link:
        return True
    fragment = link.split("#")[1]
    return fragment.replace(".", "/") in link
Provide `incl_base_ancestors` to override `incl_ancestors` for base classes/attributes. If `incl_refs` is True, extends the asset with the references found in the content of the object. It can also be an integer indicating the maximum reference level. Defaults to False for modules and classes, and True otherwise. If resolution of reference names is disabled, defaults to False. If `incl_descendants` is True, extends the asset page or heading with any descendant headings. Provide `incl_ancestor_descendants` and `incl_ref_descendants` to override `incl_descendants` for ancestors and references respectively. If `aggregate` is True, aggregates any descendant headings into pages for this object and all base classes/attributes. Provide `aggregate_ancestors` and `aggregate_refs` to override `aggregate` for ancestors and references respectively. If `topo_sort` is True, creates a topological graph from all reference names and sorts pages and headings based on this graph. Use `return_refname_graph` to True to also return the graph.""" from vectorbtpro.utils.module_ import prepare_refname, annotate_refname_parts incl_bases = self.resolve_setting(incl_bases, "incl_bases") incl_ancestors = self.resolve_setting(incl_ancestors, "incl_ancestors") incl_base_ancestors = self.resolve_setting(incl_base_ancestors, "incl_base_ancestors") incl_refs = self.resolve_setting(incl_refs, "incl_refs") incl_descendants = self.resolve_setting(incl_descendants, "incl_descendants") incl_ancestor_descendants = self.resolve_setting(incl_ancestor_descendants, "incl_ancestor_descendants") incl_ref_descendants = self.resolve_setting(incl_ref_descendants, "incl_ref_descendants") aggregate = self.resolve_setting(aggregate, "aggregate") aggregate_ancestors = self.resolve_setting(aggregate_ancestors, "aggregate_ancestors") aggregate_refs = self.resolve_setting(aggregate_refs, "aggregate_refs") topo_sort = self.resolve_setting(topo_sort, "topo_sort") base_refnames = [] base_refnames_set = set() if not 
isinstance(obj, list): objs = [obj] else: objs = obj for obj in objs: if attr is not None: checks.assert_instance_of(attr, str, arg_name="attr") if isinstance(obj, tuple): obj = (*obj, attr) else: obj = (obj, attr) obj_refname = prepare_refname(obj, module=module, resolve=resolve) refname_graph = defaultdict(list) if resolve: annotated_parts = annotate_refname_parts(obj_refname) if isinstance(annotated_parts[-1]["obj"], ModuleType): _module = annotated_parts[-1]["obj"] _cls = None _attr = None elif isinstance(annotated_parts[-1]["obj"], type): _module = None _cls = annotated_parts[-1]["obj"] _attr = None elif len(annotated_parts) >= 2 and isinstance(annotated_parts[-2]["obj"], type): _module = None _cls = annotated_parts[-2]["obj"] _attr = annotated_parts[-1]["name"] else: _module = None _cls = None _attr = None if use_parent is None: use_parent = _cls is not None and _attr is None if not aggregate and not incl_descendants: use_parent = False use_base_parents = False if incl_refs is None: incl_refs = _module is None and _cls is None if _cls is not None and incl_bases: level_classes = defaultdict(set) visited = set() queue = deque([(_cls, 0)]) while queue: current_cls, current_level = queue.popleft() if current_cls in visited: continue visited.add(current_cls) level_classes[current_level].add(current_cls) for base in current_cls.__bases__: queue.append((base, current_level + 1)) mro = inspect.getmro(_cls) classes = [] levels = list(level_classes.keys()) if not isinstance(incl_bases, bool): if isinstance(incl_bases, int): levels = levels[: incl_bases + 1] else: raise TypeError(f"Invalid incl_bases: {incl_bases}") for level in levels: classes.extend([_cls for _cls in mro if _cls in level_classes[level]]) for c in classes: if c.__module__.split(".")[0] != "vectorbtpro": continue if _attr is not None: if not hasattr(c, _attr): continue refname = prepare_refname((c, _attr)) else: refname = prepare_refname(c) if (use_parent and refname == obj_refname) or use_base_parents: 
refname = ".".join(refname.split(".")[:-1]) if refname not in base_refnames_set: base_refnames.append(refname) base_refnames_set.add(refname) for b in c.__bases__: if b.__module__.split(".")[0] == "vectorbtpro": if _attr is not None: if not hasattr(b, _attr): continue b_refname = prepare_refname((b, _attr)) else: b_refname = prepare_refname(b) if use_base_parents: b_refname = ".".join(b_refname.split(".")[:-1]) if refname != b_refname: refname_graph[refname].append(b_refname) elif _module is not None and hasattr(_module, "__path__") and incl_bases: base_refnames.append(_module.__name__) base_refnames_set.add(_module.__name__) refname_level = {} refname_level[_module.__name__] = 0 for _, refname, _ in pkgutil.walk_packages(_module.__path__, prefix=f"{_module.__name__}."): if refname not in base_refnames_set: parent_refname = ".".join(refname.split(".")[:-1]) if not isinstance(incl_bases, bool): if isinstance(incl_bases, int): if refname_level[parent_refname] + 1 > incl_bases: continue else: raise TypeError(f"Invalid incl_bases: {incl_bases}") base_refnames.append(refname) base_refnames_set.add(refname) refname_level[refname] = refname_level[parent_refname] + 1 if parent_refname != refname: refname_graph[parent_refname].append(refname) else: base_refnames.append(obj_refname) base_refnames_set.add(obj_refname) else: if incl_refs is None: incl_refs = False base_refnames.append(obj_refname) base_refnames_set.add(obj_refname) api_asset = self.find_refname( base_refnames, single_item=False, incl_descendants=incl_descendants, aggregate=aggregate, aggregate_kwargs=aggregate_kwargs, allow_empty=True, wrap=True, ) if len(api_asset) == 0: return api_asset if not topo_sort: refname_indices = {refname: [] for refname in base_refnames} remaining_indices = [] for i, d in enumerate(api_asset): refname = self.parse_link_refname(d["link"]) if refname is not None: while refname not in refname_indices: if not refname: break refname = ".".join(refname.split(".")[:-1]) if refname: 
refname_indices[refname].append(i) else: remaining_indices.append(i) get_indices = [i for v in refname_indices.values() for i in v] + remaining_indices api_asset = api_asset.get_items(get_indices) if incl_ancestors or incl_refs: refnames_aggregated = {} for d in api_asset: refname = self.parse_link_refname(d["link"]) if refname is not None: refnames_aggregated[refname] = aggregate to_ref_api_asset = api_asset if incl_ancestors: anc_refnames = [] anc_refnames_set = set(refnames_aggregated.keys()) for d in api_asset: child_refname = refname = self.parse_link_refname(d["link"]) if refname is not None: if incl_base_ancestors or refname == obj_refname: refname = ".".join(refname.split(".")[:-1]) anc_level = 1 while refname: if isinstance(incl_base_ancestors, bool) or refname == obj_refname: if not isinstance(incl_ancestors, bool): if isinstance(incl_ancestors, int): if anc_level > incl_ancestors: break else: raise TypeError(f"Invalid incl_ancestors: {incl_ancestors}") else: if not isinstance(incl_base_ancestors, bool): if isinstance(incl_base_ancestors, int): if anc_level > incl_base_ancestors: break else: raise TypeError(f"Invalid incl_base_ancestors: {incl_base_ancestors}") if refname not in anc_refnames_set: anc_refnames.append(refname) anc_refnames_set.add(refname) if refname != child_refname: refname_graph[refname].append(child_refname) child_refname = refname refname = ".".join(refname.split(".")[:-1]) anc_level += 1 anc_api_asset = self.find_refname( anc_refnames, single_item=False, incl_descendants=incl_ancestor_descendants, aggregate=aggregate_ancestors, aggregate_kwargs=aggregate_kwargs, allow_empty=True, wrap=True, ) if aggregate_ancestors or incl_ancestor_descendants: obj_index = None for i, d in enumerate(api_asset): d_refname = self.parse_link_refname(d["link"]) if d_refname == obj_refname: obj_index = i break if obj_index is not None: del api_asset[obj_index] for d in anc_api_asset: refname = self.parse_link_refname(d["link"]) if refname is not None: 
refnames_aggregated[refname] = aggregate_ancestors api_asset = anc_api_asset + api_asset if incl_refs: if not aggregate and not incl_descendants: use_ref_parents = False main_ref_api_asset = None ref_api_asset = to_ref_api_asset while incl_refs: content_refnames = [] content_refnames_set = set() for d in ref_api_asset: d_refname = self.parse_link_refname(d["link"]) if d_refname is not None: for link in self.parse_content_links(d["content"]): if "/api/" in link: refname = self.parse_link_refname(link) if refname is not None: if use_ref_parents and not self.is_link_module(link): refname = ".".join(refname.split(".")[:-1]) if refname not in content_refnames_set: content_refnames.append(refname) content_refnames_set.add(refname) if d_refname != refname and refname not in refname_graph: refname_graph[d_refname].append(refname) ref_refnames = [] ref_refnames_set = set(refnames_aggregated.keys()) | content_refnames_set for refname in content_refnames: if refname in refnames_aggregated and (refnames_aggregated[refname] or not aggregate_refs): continue _refname = refname while _refname: _refname = ".".join(_refname.split(".")[:-1]) if _refname in ref_refnames_set and refnames_aggregated.get(_refname, aggregate_refs): break if not _refname: ref_refnames.append(refname) if len(ref_refnames) == 0: break ref_api_asset = self.find_refname( ref_refnames, single_item=False, incl_descendants=incl_ref_descendants, aggregate=aggregate_refs, aggregate_kwargs=aggregate_kwargs, allow_empty=True, wrap=True, ) for d in ref_api_asset: refname = self.parse_link_refname(d["link"]) if refname is not None: refnames_aggregated[refname] = aggregate_refs if main_ref_api_asset is None: main_ref_api_asset = ref_api_asset else: main_ref_api_asset += ref_api_asset incl_refs -= 1 if main_ref_api_asset is not None: api_asset += main_ref_api_asset aggregated_refnames_set = set() for refname, aggregated in refnames_aggregated.items(): if aggregated: aggregated_refnames_set.add(refname) delete_indices = 
[] for i, d in enumerate(api_asset): refname = self.parse_link_refname(d["link"]) if refname is not None: if not refnames_aggregated[refname] and refname in aggregated_refnames_set: delete_indices.append(i) continue while refname: refname = ".".join(refname.split(".")[:-1]) if refname in aggregated_refnames_set: break if refname: delete_indices.append(i) if len(delete_indices) > 0: api_asset.delete_items(delete_indices, inplace=True) if topo_sort: from graphlib import TopologicalSorter refname_topo_graph = defaultdict(set) refname_topo_sorter = TopologicalSorter(refname_topo_graph) for parent_node, child_nodes in refname_graph.items(): for child_node in child_nodes: refname_topo_sorter.add(child_node, parent_node) refname_topo_order = refname_topo_sorter.static_order() refname_indices = {refname: [] for refname in refname_topo_order} remaining_indices = [] for i, d in enumerate(api_asset): refname = self.parse_link_refname(d["link"]) if refname is not None: while refname not in refname_indices: if not refname: break refname = ".".join(refname.split(".")[:-1]) if refname: refname_indices[refname].append(i) else: remaining_indices.append(i) else: remaining_indices.append(i) get_indices = [i for v in refname_indices.values() for i in v] + remaining_indices api_asset = api_asset.get_items(get_indices) if return_refname_graph: return api_asset, refname_graph return api_asset def find_obj_docs( self, obj: tp.MaybeList, *, attr: tp.Optional[str] = None, module: tp.Union[None, str, ModuleType] = None, resolve: bool = True, incl_pages: tp.Optional[tp.MaybeIterable[str]] = None, excl_pages: tp.Optional[tp.MaybeIterable[str]] = None, page_find_mode: tp.Optional[str] = None, up_aggregate: tp.Optional[bool] = None, up_aggregate_th: tp.Union[None, int, float] = None, up_aggregate_pages: tp.Optional[bool] = None, aggregate: tp.Optional[bool] = None, aggregate_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybePagesAsset: """Find documentation relevant to object(s). 
If a link matches one of the links or link parts in `incl_pages`, it will be included, otherwise, it will be excluded if `incl_pages` is not empty. If a link matches one of the links or link parts in `excl_pages`, it will be excluded, otherwise, it will be included. Matching is done using `vectorbtpro.utils.search_.find` with `page_find_mode` used as `mode`. For example, using `excl_pages=["release-notes"]` won't search in release notes. If `up_aggregate` is True, will aggregate each set of headings into their parent if their number is greater than some threshold `up_aggregate_th`, which depends on the total number of headings in the parent. It can be an integer for absolute number or float for relative number. For example, `up_aggregate_th=2/3` means this method must find 2 headings out of 3 in order to replace it by the full parent heading/page. If `up_aggregate_pages` is True, does the same to pages. For example, if 2 tutorial pages out of 3 are matched, the whole tutorial series is used. If `aggregate` is True, aggregates any descendant headings into pages for this object and all base classes/attributes using `PagesAsset.aggregate_links`. 
Uses `PagesAsset.find_obj_mentions`.""" incl_pages = self.resolve_setting(incl_pages, "incl_pages") excl_pages = self.resolve_setting(excl_pages, "excl_pages") page_find_mode = self.resolve_setting(page_find_mode, "page_find_mode") up_aggregate = self.resolve_setting(up_aggregate, "up_aggregate") up_aggregate_th = self.resolve_setting(up_aggregate_th, "up_aggregate_th") up_aggregate_pages = self.resolve_setting(up_aggregate_pages, "up_aggregate_pages") aggregate = self.resolve_setting(aggregate, "aggregate") if incl_pages is None: incl_pages = () elif isinstance(incl_pages, str): incl_pages = (incl_pages,) if excl_pages is None: excl_pages = () elif isinstance(excl_pages, str): excl_pages = (excl_pages,) def _filter_func(x): if "link" not in x: return False if "/api/" in x["link"]: return False if excl_pages: for page in excl_pages: if find(page, x["link"], mode=page_find_mode): return False if incl_pages: for page in incl_pages: if find(page, x["link"], mode=page_find_mode): return True return False return True docs_asset = self.filter(_filter_func) mentions_asset = docs_asset.find_obj_mentions( obj, attr=attr, module=module, resolve=resolve, **kwargs, ) if ( isinstance(mentions_asset, PagesAsset) and len(mentions_asset) > 0 and isinstance(mentions_asset[0], dict) and "link" in mentions_asset[0] ): if up_aggregate: link_map = {d["link"]: dict(d) for d in docs_asset.data} new_links = {d["link"] for d in mentions_asset} while True: parent_map = defaultdict(list) without_parent = set() for link in new_links: if link_map[link]["parent"] is not None: parent_map[link_map[link]["parent"]].append(link) else: without_parent.add(link) _new_links = set() for parent, children in parent_map.items(): headings = set() non_headings = set() for child in children: if link_map[child]["type"].startswith("heading"): headings.add(child) else: non_headings.add(child) if up_aggregate_pages: _children = children else: _children = headings if checks.is_float(up_aggregate_th) and 0 <= 
def aggregate(
    self: PagesAssetT,
    append_obj_type: tp.Optional[bool] = None,
    append_github_link: tp.Optional[bool] = None,
) -> PagesAssetT:
    """Aggregate pages.

    Content of each heading will be converted into markdown and concatenated into the content
    of the parent heading or page. Only regular pages and headings without parents will be left.

    If `append_obj_type` is True, will also append object type to the heading name.
    If `append_github_link` is True, will also append GitHub link to the heading name.

    Both arguments are first resolved with `resolve_setting`. Each data item is shallow-copied
    before its "content" and "children" fields are rewritten, and the surviving items are
    returned through `replace`."""
    append_obj_type = self.resolve_setting(append_obj_type, "append_obj_type")
    append_github_link = self.resolve_setting(append_github_link, "append_github_link")
    # Shallow copies: "content" and "children" are reassigned on the copies below
    link_map = {d["link"]: dict(d) for d in self.data}
    top_parents = self.top_parent_links
    # Links of headings whose content was folded into a parent; dropped from the result
    aggregated_links = set()

    def _aggregate_content(link):
        # Build the markdown of `link` including all of its heading descendants (depth-first),
        # store it on the node, and return it
        node = link_map[link]
        content = node["content"]
        if content is None:
            content = ""
        if node["type"].startswith("heading"):
            # The type is expected to look like "heading <level>"
            level = int(node["type"].split(" ")[1])
            heading_markdown = "#" * level + " " + node["name"]
            if append_obj_type and node.get("obj_type", None) is not None:
                heading_markdown += f" | {node['obj_type']}"
            if append_github_link and node.get("github_link", None) is not None:
                heading_markdown += f" | [source]({node['github_link']})"
            if content == "":
                content = heading_markdown
            else:
                content = f"{heading_markdown}\n\n{content}"

        children = list(node["children"])
        # Iterate over a copy since `children` is modified inside the loop
        for child in list(children):
            if child in link_map:
                child_node = link_map[child]
                # Children are always processed, but only headings are merged into the parent
                child_content = _aggregate_content(child)
                if child_node["type"].startswith("heading"):
                    if child_content.startswith("# "):
                        # A level-1 heading replaces the content collected so far
                        # NOTE(review): presumably it already acts as the page title - confirm
                        content = child_content
                    else:
                        content += f"\n\n{child_content}"
                    children.remove(child)
                    aggregated_links.add(child)

        if content != "":
            node["content"] = content
        node["children"] = children
        return content

    for top_parent in top_parents:
        _aggregate_content(top_parent)

    new_data = [link_map[link] for link in link_map if link not in aggregated_links]
    return self.replace(data=new_data)
def select_siblings(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
    """Select the sibling pages of a link.

    Siblings are the other children of the link's parent, in the order in which the parent
    lists them. Children that are not part of this asset are skipped. If the item has no parent,
    or the parent is not part of this asset, there are no siblings.

    If `incl_link` is True, the item itself is placed first and appears exactly once.

    Keyword arguments are passed to `find_page`."""
    d = self.find_page(link, wrap=False, **kwargs)
    link_map = {x["link"]: dict(x) for x in self.data}
    new_data = [d] if incl_link else []
    parent = d.get("parent", None)
    if parent and parent in link_map:
        for child in link_map[parent].get("children", None) or ():
            # The item itself is never its own sibling; when requested, it was already
            # added above, so adding it here again would duplicate it
            if child != d["link"] and child in link_map:
                new_data.append(link_map[child])
    return self.replace(data=new_data, single_item=False)
def select_ancestors(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
    """Select all ancestor pages of a link.

    Walks the "parent" field upwards, starting at the item found for `link`, and collects every
    parent that is part of this asset, nearest first. The walk stops at the first missing parent
    or as soon as a parent would be visited twice.

    If `incl_link` is True, the item itself is placed first.

    Keyword arguments are passed to `find_page`."""
    start = self.find_page(link, wrap=False, **kwargs)
    nodes_by_link = {item["link"]: dict(item) for item in self.data}
    selected = [start] if incl_link else []
    visited = set()
    current = start.get("parent", None)
    while current and current in nodes_by_link and current not in visited:
        visited.add(current)
        node = nodes_by_link[current]
        selected.append(node)
        current = node.get("parent", None)
    return self.replace(data=selected, single_item=False)
**dir_tree_kwargs, ) -> None: """Print site schema. If `structure_fragments` is True, builds a hierarchy of fragments. Otherwise, displays them on the same level. If `split_fragments` is True, displays fragments as continuation of their parents. Otherwise, displays them in full length. Keyword arguments are split between `KnowledgeAsset.describe` and `vectorbtpro.utils.path_.dir_tree_from_paths`.""" link_map = {d["link"]: dict(d) for d in self.data} links = [] for link, d in link_map.items(): if not structure_fragments: links.append(link) continue x = d link_base = None link_fragments = [] while x["type"].startswith("heading") and "#" in x["link"]: link_parts = x["link"].split("#") if link_base is None: link_base = link_parts[0] link_fragments.append("#" + link_parts[1]) if not x.get("parent", None) or x["parent"] not in link_map: if x["type"].startswith("heading"): level = int(x["type"].split()[1]) for i in range(level - 1): link_fragments.append("?") break x = link_map[x["parent"]] if link_base is None: links.append(link) else: if split_fragments and len(link_fragments) > 1: link_fragments = link_fragments[::-1] new_link_fragments = [link_fragments[0]] for i in range(1, len(link_fragments)): link_fragment1 = link_fragments[i - 1] link_fragment2 = link_fragments[i] if link_fragment2.startswith(link_fragment1 + "."): new_link_fragments.append("." 
+ link_fragment2[len(link_fragment1 + ".") :]) else: new_link_fragments.append(link_fragment2) link_fragments = new_link_fragments links.append(link_base + "/".join(link_fragments)) paths = self.links_to_paths(links, allow_fragments=not structure_fragments) path_names = [] for i, d in enumerate(link_map.values()): path_name = paths[i].name brackets = [] if append_type: brackets.append(d["type"]) if append_obj_type and d["obj_type"]: brackets.append(d["obj_type"]) if brackets: path_name += f" [{', '.join(brackets)}]" path_names.append(path_name) if "root_name" not in dir_tree_kwargs: root_name = get_common_prefix(link_map.keys()) if not root_name: root_name = "/" dir_tree_kwargs["root_name"] = root_name if "sort" not in dir_tree_kwargs: dir_tree_kwargs["sort"] = False if "path_names" not in dir_tree_kwargs: dir_tree_kwargs["path_names"] = path_names if "length_limit" not in dir_tree_kwargs: dir_tree_kwargs["length_limit"] = None print(dir_tree_from_paths(paths, **dir_tree_kwargs)) MessagesAssetT = tp.TypeVar("MessagesAssetT", bound="MessagesAsset") class MessagesAsset(VBTAsset): """Class for working with Discord messages. Each message has the following fields: link: URL of the message, such as "https://discord.com/channels/918629562441695344/919715148896301067/923327319882485851" block: URL of the first message in the block. A block is a bunch of messages of the same author that either reference a message of another author, or don't reference any message at all. thread: URL of the first message in the thread. A thread is a bunch of blocks that reference each other in a chain, such as questions, answers, follow-up questions, etc. reference: URL of the message that the message references. Can be None. 
def latest_first(self, **kwargs) -> tp.MaybeMessagesAsset:
    """Sort by timestamp in descending order.

    The values of the "timestamp" field are used as sorting keys.
    Keyword arguments are passed to `sort`."""
    timestamps = self.get("timestamp")
    return self.sort(keys=timestamps, ascending=False, **kwargs)
def aggregate_blocks(
    self: MessagesAssetT,
    collect_kwargs: tp.KwargsLike = None,
    aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
    parent_links_only: tp.Optional[bool] = None,
    minimize_metadata: tp.Optional[bool] = None,
    minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
    clean_metadata: tp.Optional[bool] = None,
    clean_metadata_kwargs: tp.KwargsLike = None,
    dump_metadata_kwargs: tp.KwargsLike = None,
    to_markdown_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeMessagesAsset:
    """Aggregate messages by block.

    First, uses `MessagesAsset.reduce` on `vectorbtpro.utils.knowledge.base_asset_funcs.CollectAssetFunc`
    to collect data items by the field "block". Keyword arguments in `collect_kwargs` are passed here.
    Argument `uniform_groups` is True by default. The dictionary passed as `collect_kwargs`
    is copied and never modified.

    Then, uses `MessagesAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.AggBlockAssetFunc`
    to aggregate each collected data item. A mapping from each message link to a copy of
    the message is passed along as `link_map`.

    Use `aggregate_fields` to provide a set of fields to be aggregated rather than used in child metadata.
    It can be True to aggregate all lists and False to aggregate none.

    If `parent_links_only` is True, doesn't include links in the metadata of each message.

    For other keyword arguments, see `MessagesAsset.to_markdown`."""
    # Work on a copy so that the default below doesn't leak into the caller's dictionary
    collect_kwargs = {} if collect_kwargs is None else dict(collect_kwargs)
    collect_kwargs.setdefault("uniform_groups", True)
    instance = self.collect(by="block", wrap=True, **collect_kwargs)
    return instance.apply(
        "agg_block",
        aggregate_fields=aggregate_fields,
        parent_links_only=parent_links_only,
        minimize_metadata=minimize_metadata,
        minimize_keys=minimize_keys,
        clean_metadata=clean_metadata,
        clean_metadata_kwargs=clean_metadata_kwargs,
        dump_metadata_kwargs=dump_metadata_kwargs,
        to_markdown_kwargs=to_markdown_kwargs,
        link_map={d["link"]: dict(d) for d in self.data},
        **kwargs,
    )
def aggregate_channels(
    self: MessagesAssetT,
    collect_kwargs: tp.KwargsLike = None,
    aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
    parent_links_only: tp.Optional[bool] = None,
    minimize_metadata: tp.Optional[bool] = None,
    minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
    clean_metadata: tp.Optional[bool] = None,
    clean_metadata_kwargs: tp.KwargsLike = None,
    dump_metadata_kwargs: tp.KwargsLike = None,
    to_markdown_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeMessagesAsset:
    """Aggregate messages by channel.

    Same as `MessagesAsset.aggregate_threads` but for channels: data items are collected by
    the field "channel" (with `uniform_groups` being True by default) and then aggregated.
    The dictionary passed as `collect_kwargs` is copied and never modified.

    Uses `vectorbtpro.utils.knowledge.custom_asset_funcs.AggChannelAssetFunc`."""
    # Work on a copy so that the default below doesn't leak into the caller's dictionary
    collect_kwargs = {} if collect_kwargs is None else dict(collect_kwargs)
    collect_kwargs.setdefault("uniform_groups", True)
    instance = self.collect(by="channel", wrap=True, **collect_kwargs)
    return instance.apply(
        "agg_channel",
        aggregate_fields=aggregate_fields,
        parent_links_only=parent_links_only,
        minimize_metadata=minimize_metadata,
        minimize_keys=minimize_keys,
        clean_metadata=clean_metadata,
        clean_metadata_kwargs=clean_metadata_kwargs,
        dump_metadata_kwargs=dump_metadata_kwargs,
        to_markdown_kwargs=to_markdown_kwargs,
        link_map={d["link"]: dict(d) for d in self.data},
        **kwargs,
    )
def select_reference(self: MessagesAssetT, link: str, **kwargs) -> MessagesAssetT:
    """Select the reference message.

    The reference message is the message whose "link" equals the "reference" field of
    the message found for `link`. If the message doesn't reference anything, or
    the referenced message is not part of this asset, the returned asset has no data.

    Keyword arguments are passed to `find_link`."""
    d = self.find_link(link, wrap=False, **kwargs)
    reference = d.get("reference", None)
    new_data = []
    if reference:
        for d2 in self.data:
            # Match on the link: matching on "reference" would return a message that merely
            # references the same message (possibly the queried message itself)
            if d2["link"] == reference:
                new_data.append(d2)
                break
    return self.replace(data=new_data, single_item=True)
def is_obj_or_query_ref(obj_or_query: tp.MaybeList) -> bool:
    """Return whether `obj_or_query` is a reference to an object.

    A string is a reference only if each of its dot-separated segments is a valid identifier,
    such as "vbt.Portfolio". Anything that is not a string is always considered a reference."""
    if not isinstance(obj_or_query, str):
        return True
    for part in obj_or_query.split("."):
        if not part.isidentifier():
            return False
    return True
def find_docs(
    obj_or_query: tp.Optional[tp.MaybeList] = None,
    *,
    as_query: tp.Optional[bool] = None,
    attr: tp.Optional[str] = None,
    module: tp.Union[None, str, ModuleType] = None,
    resolve: bool = True,
    pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
    pull_kwargs: tp.KwargsLike = None,
    aggregate: bool = False,
    aggregate_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybePagesAsset:
    """Find documentation pages and headings relevant to object(s) or a query.

    Argument `pages_asset` can be a subclass of `PagesAsset`, which is pulled with `pull_kwargs`,
    or an instance of `PagesAsset`, which is used as-is. It defaults to `PagesAsset`.
    If `aggregate` is True, the asset is aggregated with `aggregate_kwargs` before searching.

    If `obj_or_query` is None, returns all documentation pages, that is, all items
    that have a link without "/api/". If it's a reference to an object, uses
    `PagesAsset.find_obj_docs`. Otherwise, uses `PagesAsset.rank` on the documentation pages.

    If `as_query` is None, it's determined with `is_obj_or_query_ref`."""
    asset = PagesAsset if pages_asset is None else pages_asset
    if isinstance(asset, type):
        checks.assert_subclass_of(asset, PagesAsset, arg_name="pages_asset")
        asset = asset.pull(**({} if pull_kwargs is None else pull_kwargs))
    else:
        checks.assert_instance_of(asset, PagesAsset, arg_name="pages_asset")
    if aggregate:
        asset = asset.aggregate(**({} if aggregate_kwargs is None else aggregate_kwargs))

    has_target = obj_or_query is not None
    if as_query is None:
        as_query = has_target and not is_obj_or_query_ref(obj_or_query)
    if has_target and not as_query:
        return asset.find_obj_docs(
            obj_or_query,
            attr=attr,
            module=module,
            resolve=resolve,
            **kwargs,
        )

    docs_asset = asset.filter(lambda page: "link" in page and "/api/" not in page["link"])
    if not has_target:
        return docs_asset
    return docs_asset.rank(obj_or_query, **kwargs)
def find_examples(
    obj_or_query: tp.Optional[tp.MaybeList] = None,
    *,
    as_query: tp.Optional[bool] = None,
    attr: tp.Optional[str] = None,
    module: tp.Union[None, str, ModuleType] = None,
    resolve: bool = True,
    as_code: bool = True,
    return_type: tp.Optional[str] = "field",
    pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
    messages_asset: tp.Optional[tp.MaybeType[MessagesAssetT]] = None,
    pull_kwargs: tp.KwargsLike = None,
    aggregate_pages: bool = False,
    aggregate_pages_kwargs: tp.KwargsLike = None,
    aggregate_messages: tp.Union[bool, str] = "messages",
    aggregate_messages_kwargs: tp.KwargsLike = None,
    latest_messages_first: bool = False,
    shuffle_messages: bool = False,
    find_kwargs: tp.KwargsLike = None,
    **kwargs,
) -> tp.MaybeVBTAsset:
    """Find (code) examples relevant to object(s) or a query.

    If `obj_or_query` is None, returns all examples with `VBTAsset.find_code` or `VBTAsset.find`.
    If it's a reference to an object, uses `VBTAsset.find_obj_mentions`. Otherwise, uses
    `VBTAsset.find_code` or `VBTAsset.find` and then `VBTAsset.rank`. Keyword arguments are
    distributed among these methods automatically, unless some keys cannot be found in both
    signatures. In such a case, the key will be used for ranking. If this is not wanted,
    specify `find_kwargs`.

    By default, extracts code with text. Use `return_type="match"` to extract code without text,
    or, for instance, `return_type="item"` to also get links.

    Pages and messages are prepared separately and then concatenated into one asset. Each of
    `pages_asset` and `messages_asset` can be a class, which is pulled with `pull_kwargs`, or
    an instance. If `aggregate_messages` is a string, it's used as `by` unless
    `aggregate_messages_kwargs` already provides one. Dictionaries passed as arguments
    are never modified."""
    if pages_asset is None:
        pages_asset = PagesAsset
    if isinstance(pages_asset, type):
        checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset")
        if pull_kwargs is None:
            pull_kwargs = {}
        pages_asset = pages_asset.pull(**pull_kwargs)
    else:
        checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset")
    if aggregate_pages:
        if aggregate_pages_kwargs is None:
            aggregate_pages_kwargs = {}
        pages_asset = pages_asset.aggregate(**aggregate_pages_kwargs)

    if messages_asset is None:
        messages_asset = MessagesAsset
    if isinstance(messages_asset, type):
        checks.assert_subclass_of(messages_asset, MessagesAsset, arg_name="messages_asset")
        if pull_kwargs is None:
            pull_kwargs = {}
        messages_asset = messages_asset.pull(**pull_kwargs)
    else:
        checks.assert_instance_of(messages_asset, MessagesAsset, arg_name="messages_asset")
    if aggregate_messages:
        # Copy before writing "by" so that the caller's dictionary stays untouched,
        # just like `find_kwargs` below
        if aggregate_messages_kwargs is None:
            aggregate_messages_kwargs = {}
        else:
            aggregate_messages_kwargs = dict(aggregate_messages_kwargs)
        if isinstance(aggregate_messages, str) and "by" not in aggregate_messages_kwargs:
            aggregate_messages_kwargs["by"] = aggregate_messages
        messages_asset = messages_asset.aggregate(**aggregate_messages_kwargs)
    if latest_messages_first:
        messages_asset = messages_asset.latest_first()
    elif shuffle_messages:
        messages_asset = messages_asset.shuffle()

    combined_asset = pages_asset + messages_asset
    if as_query is None:
        as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
    if find_kwargs is None:
        find_kwargs = {}
    else:
        find_kwargs = dict(find_kwargs)
    # The explicit argument always wins over any "return_type" inside `find_kwargs`
    find_kwargs["return_type"] = return_type

    if obj_or_query is not None and not as_query:
        return combined_asset.find_obj_mentions(
            obj_or_query,
            attr=attr,
            module=module,
            resolve=resolve,
            as_code=as_code,
            **find_kwargs,
            **kwargs,
        )

    if as_code:
        method = combined_asset.find_code
    else:
        method = combined_asset.find
    if obj_or_query is None:
        return method(**find_kwargs, **kwargs)

    # Split keyword arguments: those accepted by the find method go there (without
    # overriding `find_kwargs`), everything else is used for ranking
    find_code_arg_names = set(get_func_arg_names(method))
    rank_kwargs = {}
    for k, v in kwargs.items():
        if k in find_code_arg_names:
            if k not in find_kwargs:
                find_kwargs[k] = v
        else:
            rank_kwargs[k] = v
    return method(**find_kwargs).rank(obj_or_query, **rank_kwargs)
It defaults to "api", "docs", and "messages", It can also include ellipsis (`...`). For example, `["messages", ...]` puts "messages" at the beginning and all other assets in their usual order at the end. The following asset names are supported: * "api": `find_api` with `api_kwargs` * "docs": `find_docs` with `docs_kwargs` * "messages": `find_messages` with `messages_kwargs` * "examples": `find_examples` with `examples_kwargs` * "all": All of the above !!! note Examples usually overlap with other assets, thus they are excluded by default. Set `combine` to True to combine all assets into a single asset. Uses `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.combine` with `combine_kwargs`. If `obj_or_query` is a query, will rank the combined asset. Otherwise, will rank each individual asset. Set `minimize` to True (or `minimize_pages` for pages and `minimize_messages` for messages) in order to minimize to remove fields that aren't relevant for chatting. It defaults to True if `combine` is True, otherwise, it defaults to False. Uses `VBTAsset.minimize` with `minimize_kwargs`, `PagesAsset.minimize` with `minimize_pages_kwargs`, and `MessagesAsset.minimize` with `minimize_messages_kwargs`. Arguments `minimize_pages_kwargs` and `minimize_messages_kwargs` are merged over `minimize_kwargs`. Keyword arguments are passed to all functions (except for `find_api` when `obj_or_query` is an object since it doesn't share common arguments with other three functions), unless `combine` and `as_query` are both True; in this case they are passed to `VBTAsset.rank`. 
Use specialized arguments like `api_kwargs` to provide keyword arguments to the respective function.""" if pages_asset is None: pages_asset = PagesAsset if isinstance(pages_asset, type): checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset") if pull_kwargs is None: pull_kwargs = {} pages_asset = pages_asset.pull(**pull_kwargs) else: checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset") if aggregate_pages: if aggregate_pages_kwargs is None: aggregate_pages_kwargs = {} pages_asset = pages_asset.aggregate(**aggregate_pages_kwargs) if messages_asset is None: messages_asset = MessagesAsset if isinstance(messages_asset, type): checks.assert_subclass_of(messages_asset, MessagesAsset, arg_name="messages_asset") if pull_kwargs is None: pull_kwargs = {} messages_asset = messages_asset.pull(**pull_kwargs) else: checks.assert_instance_of(messages_asset, MessagesAsset, arg_name="messages_asset") if aggregate_messages: if aggregate_messages_kwargs is None: aggregate_messages_kwargs = {} if isinstance(aggregate_messages, str) and "by" not in aggregate_messages_kwargs: aggregate_messages_kwargs["by"] = aggregate_messages messages_asset = messages_asset.aggregate(**aggregate_messages_kwargs) if latest_messages_first: messages_asset = messages_asset.latest_first() elif shuffle_messages: messages_asset = messages_asset.shuffle() if as_query is None: as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query) if combine and as_query and obj_or_query is not None: if api_kwargs is None: api_kwargs = {} if docs_kwargs is None: docs_kwargs = {} if messages_kwargs is None: messages_kwargs = {} if examples_kwargs is None: examples_kwargs = {} else: if as_query: api_kwargs = merge_dicts(kwargs, api_kwargs) else: if api_kwargs is None: api_kwargs = {} docs_kwargs = merge_dicts(kwargs, docs_kwargs) messages_kwargs = merge_dicts(kwargs, messages_kwargs) examples_kwargs = merge_dicts(kwargs, examples_kwargs) asset_dict = {} 
all_asset_names = ["api", "docs", "messages", "examples"] if asset_names is not None: if isinstance(asset_names, str) and asset_names.lower() == "all": asset_names = all_asset_names else: if isinstance(asset_names, (str, type(Ellipsis))): asset_names = [asset_names] asset_keys = [] for asset_name in asset_names: if asset_name is not Ellipsis: asset_key = all_asset_names.index(asset_name.lower()) if asset_key == -1: raise ValueError(f"Invalid asset name: '{asset_name}'") asset_keys.append(asset_key) else: asset_keys.append(Ellipsis) new_asset_names = reorder_list(all_asset_names, asset_keys, skip_missing=True) if "examples" not in asset_names and "examples" in new_asset_names: new_asset_names.remove("examples") asset_names = new_asset_names else: asset_names = ["api", "docs", "messages"] for asset_name in asset_names: if asset_name == "api": asset = find_api( None if combine and as_query else obj_or_query, as_query=as_query, attr=attr, module=module, resolve=resolve, pages_asset=pages_asset, aggregate=False, **api_kwargs, ) if len(asset) > 0: asset_dict[asset_name] = asset elif asset_name == "docs": asset = find_docs( None if combine and as_query else obj_or_query, as_query=as_query, attr=attr, module=module, resolve=resolve, pages_asset=pages_asset, aggregate=False, **docs_kwargs, ) if len(asset) > 0: asset_dict[asset_name] = asset elif asset_name == "messages": asset = find_messages( None if combine and as_query else obj_or_query, as_query=as_query, attr=attr, module=module, resolve=resolve, messages_asset=messages_asset, aggregate=False, latest_first=False, **messages_kwargs, ) if len(asset) > 0: asset_dict[asset_name] = asset elif asset_name == "examples": if examples_kwargs is None: examples_kwargs = {} asset = find_examples( None if combine and as_query else obj_or_query, as_query=as_query, attr=attr, module=module, resolve=resolve, pages_asset=pages_asset, messages_asset=messages_asset, aggregate_messages=False, aggregate_pages=False, 
latest_messages_first=False, **examples_kwargs, ) if len(asset) > 0: asset_dict[asset_name] = asset if minimize is None: minimize = combine and not as_query if minimize: if minimize_kwargs is None: minimize_kwargs = {} for k, v in asset_dict.items(): if ( isinstance(v, VBTAsset) and not isinstance(v, (PagesAsset, MessagesAsset)) and len(v) > 0 and not isinstance(v[0], str) ): asset_dict[k] = v.minimize(**minimize_kwargs) if minimize_pages is None: minimize_pages = minimize if minimize_pages: minimize_pages_kwargs = merge_dicts(minimize_kwargs, minimize_pages_kwargs) for k, v in asset_dict.items(): if isinstance(v, PagesAsset) and len(v) > 0 and not isinstance(v[0], str): asset_dict[k] = v.minimize(**minimize_pages_kwargs) if minimize_messages is None: minimize_messages = minimize if minimize_messages: minimize_messages_kwargs = merge_dicts(minimize_kwargs, minimize_messages_kwargs) for k, v in asset_dict.items(): if isinstance(v, MessagesAsset) and len(v) > 0 and not isinstance(v[0], str): asset_dict[k] = v.minimize(**minimize_messages_kwargs) if combine: if len(asset_dict) >= 2: if combine_kwargs is None: combine_kwargs = {} combined_asset = VBTAsset.combine(*asset_dict.values(), **combine_kwargs) elif len(asset_dict) == 1: combined_asset = list(asset_dict.values())[0] else: combined_asset = VBTAsset() if combined_asset and as_query and obj_or_query is not None: combined_asset = combined_asset.rank(obj_or_query, **kwargs) return combined_asset return asset_dict def chat_about( obj: tp.MaybeList, message: str, chat_history: tp.ChatHistory = None, *, asset_names: tp.Optional[tp.MaybeIterable[str]] = "examples", latest_messages_first: bool = True, shuffle_messages: tp.Optional[bool] = None, shuffle: tp.Optional[bool] = None, find_assets_kwargs: tp.KwargsLike = None, **kwargs, ) -> tp.MaybeChatOutput: """Chat about object(s). By default, uses examples only. Uses `find_assets` with `combine=True` and `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.chat`. 
def search(
    query: str,
    cache_documents: bool = True,
    cache_key: tp.Optional[str] = None,
    asset_cache_manager: tp.Optional[tp.MaybeType[AssetCacheManager]] = None,
    asset_cache_manager_kwargs: tp.KwargsLike = None,
    aggregate_messages: tp.Union[bool, str] = "threads",
    aggregate_messages_kwargs: tp.KwargsLike = None,
    find_assets_kwargs: tp.KwargsLike = None,
    display: tp.Union[bool, int] = 20,
    display_kwargs: tp.KwargsLike = None,
    silence_warnings: bool = False,
    **kwargs,
) -> tp.Union[tp.MaybeVBTAsset, tp.Path]:
    """Search for a query.

    Assets are collected with `find_assets` (called with `obj_or_query=None` and `as_query=True`)
    and then ranked against `query` with `rank` of the resulting asset.

    Keyword arguments whose names appear in the signature of `find_assets` are forwarded to
    `find_assets` (without overriding keys already present in `find_assets_kwargs`); all other
    keyword arguments are forwarded to ranking. To force an argument into `find_assets`,
    pass it explicitly with `find_assets_kwargs`.

    Messages are aggregated according to `aggregate_messages` and `aggregate_messages_kwargs`,
    where `minimize_metadata` is set to True unless provided.

    If `cache_documents` is True, resolves an asset cache manager (a subclass is instantiated with
    `asset_cache_manager_kwargs`, an instance is updated with them), generates a cache key from
    the `find_assets` arguments unless `cache_key` is provided, and attempts to load the asset
    from the cache. On a cache miss, emits the "Caching documents..." warning unless
    `silence_warnings` is True.

    If `display` is truthy, displays the results with `display` of the ranked asset using
    `display_kwargs` (titled with the query by default) and returns its output. Pass an integer
    to display only that many top results. Otherwise, returns the ranked asset."""
    known_find_args = set(get_func_arg_names(find_assets))
    find_assets_kwargs = {} if find_assets_kwargs is None else dict(find_assets_kwargs)
    rank_kwargs = {}
    for name, value in kwargs.items():
        if name not in known_find_args:
            rank_kwargs[name] = value
        elif name not in find_assets_kwargs:
            find_assets_kwargs[name] = value

    if aggregate_messages_kwargs is None:
        aggregate_messages_kwargs = {}
    else:
        aggregate_messages_kwargs = dict(aggregate_messages_kwargs)
    aggregate_messages_kwargs.setdefault("minimize_metadata", True)
    find_assets_kwargs["aggregate_messages"] = aggregate_messages
    find_assets_kwargs["aggregate_messages_kwargs"] = aggregate_messages_kwargs

    asset = None
    if cache_documents:
        if asset_cache_manager is None:
            asset_cache_manager = AssetCacheManager
        if asset_cache_manager_kwargs is None:
            asset_cache_manager_kwargs = {}
        if not isinstance(asset_cache_manager, type):
            checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
            if asset_cache_manager_kwargs:
                asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs)
        else:
            checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
            asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs)
        # The manager already carries these settings; do not apply them again while ranking
        asset_cache_manager_kwargs = {}
        if cache_key is None:
            cache_key = asset_cache_manager.generate_cache_key(**find_assets_kwargs)
        asset = asset_cache_manager.load_asset(cache_key)
        if asset is None:
            if not silence_warnings:
                warn("Caching documents...")
            # Warned (or silenced) once here; keep the ranking step quiet
            silence_warnings = True

    if asset is None:
        asset = find_assets(None, as_query=True, **find_assets_kwargs)
    found_asset = asset.rank(
        query,
        cache_documents=cache_documents,
        cache_key=cache_key,
        asset_cache_manager=asset_cache_manager,
        asset_cache_manager_kwargs=asset_cache_manager_kwargs,
        silence_warnings=silence_warnings,
        **rank_kwargs,
    )
    if not display:
        return found_asset

    display_kwargs = {} if display_kwargs is None else dict(display_kwargs)
    display_kwargs.setdefault("title", query)
    # A boolean means "show everything"; an integer limits the number of top results
    display_asset = found_asset if isinstance(display, bool) else found_asset[:display]
    return display_asset.display(**display_kwargs)
Keyword arguments are distributed among these two methods automatically, unless some keys cannot be found in both signatures. In such a case, the key will be used for chatting. If this is not wanted, specify the `find_assets`- related arguments explicitly with `find_assets_kwargs`. Metadata when aggregating messages will be minimized by default. If `cache_documents` is True, will use an asset cache manager to store the generated text documents in a local and/or disk cache after conversion. Running the same method again will use the cached documents.""" find_arg_names = set(get_func_arg_names(find_assets)) if find_assets_kwargs is None: find_assets_kwargs = {} else: find_assets_kwargs = dict(find_assets_kwargs) chat_kwargs = {} for k, v in kwargs.items(): if k in find_arg_names: if k not in find_assets_kwargs: find_assets_kwargs[k] = v else: chat_kwargs[k] = v find_assets_kwargs["aggregate_messages"] = aggregate_messages if aggregate_messages_kwargs is None: aggregate_messages_kwargs = {} else: aggregate_messages_kwargs = dict(aggregate_messages_kwargs) if "minimize_metadata" not in aggregate_messages_kwargs: aggregate_messages_kwargs["minimize_metadata"] = True find_assets_kwargs["aggregate_messages_kwargs"] = aggregate_messages_kwargs if cache_documents: if asset_cache_manager is None: asset_cache_manager = AssetCacheManager if asset_cache_manager_kwargs is None: asset_cache_manager_kwargs = {} if isinstance(asset_cache_manager, type): checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager") asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs) else: checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager") if asset_cache_manager_kwargs: asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs) asset_cache_manager_kwargs = {} if cache_key is None: cache_key = asset_cache_manager.generate_cache_key(**find_assets_kwargs) asset = 
asset_cache_manager.load_asset(cache_key) if asset is None: if not silence_warnings: warn("Caching documents...") silence_warnings = True else: asset = None if asset is None: asset = find_assets(None, as_query=True, **find_assets_kwargs) if rank_kwargs is None: rank_kwargs = {} else: rank_kwargs = dict(rank_kwargs) rank_kwargs["cache_documents"] = cache_documents rank_kwargs["cache_key"] = cache_key rank_kwargs["asset_cache_manager"] = asset_cache_manager rank_kwargs["asset_cache_manager_kwargs"] = asset_cache_manager_kwargs rank_kwargs["silence_warnings"] = silence_warnings if "wrap_documents" not in rank_kwargs: rank_kwargs["wrap_documents"] = wrap_documents return asset.chat( query, chat_history, rank=rank, top_k=top_k, min_top_k=min_top_k, max_top_k=max_top_k, cutoff=cutoff, return_chunks=return_chunks, rank_kwargs=rank_kwargs, **chat_kwargs, ) # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Classes for content formatting. 
class ToMarkdown(Configured):
    """Class to convert text to Markdown.

    Each option left as None is resolved from the settings with `resolve_setting`."""

    _settings_path: tp.SettingsPath = ["knowledge", "knowledge.formatting"]

    def __init__(
        self,
        remove_code_title: tp.Optional[bool] = None,
        even_indentation: tp.Optional[bool] = None,
        newline_before_list: tp.Optional[bool] = None,
        **kwargs,
    ) -> None:
        Configured.__init__(
            self,
            remove_code_title=remove_code_title,
            even_indentation=even_indentation,
            newline_before_list=newline_before_list,
            **kwargs,
        )

        self._remove_code_title = self.resolve_setting(remove_code_title, "remove_code_title")
        self._even_indentation = self.resolve_setting(even_indentation, "even_indentation")
        self._newline_before_list = self.resolve_setting(newline_before_list, "newline_before_list")

    @property
    def remove_code_title(self) -> bool:
        """Whether to remove the `title` attribute from a code block and put it above the block."""
        return self._remove_code_title

    @property
    def newline_before_list(self) -> bool:
        """Whether to add a new line before a list."""
        return self._newline_before_list

    @property
    def even_indentation(self) -> bool:
        """Whether to make leading spaces even.

        For example, 3 leading spaces become 4."""
        return self._even_indentation

    def to_markdown(self, text: str) -> str:
        """Convert text to Markdown.

        Applies, in this order and only if enabled: moving the title of a fenced code block
        above the block in bold, padding lines that start with an odd number of spaces by
        one space, and inserting an empty line before a list item that directly follows
        a non-empty line."""
        result = text

        if self.remove_code_title:
            titled_block = re.compile(r'```(\w+)\s+title="([^"]*)"\s*\n(.*?)\n```', re.DOTALL)

            def _move_title(found):
                language, title, code = found.group(1), found.group(2), found.group(3)
                # An empty title leaves just the plain fenced block
                heading = f"**{title}**\n\n" if title else ""
                return heading + f"```{language}\n{code}\n```"

            result = titled_block.sub(_move_title, result)

        if self.even_indentation:
            indent = re.compile(r"^( +)(?=\S|$|\n)")

            def _pad(line):
                found = indent.match(line)
                if found and len(found.group(0)) % 2 != 0:
                    return " " + line
                return line

            result = "".join(_pad(line) for line in result.splitlines(keepends=True))

        if self.newline_before_list:
            result = re.sub(r"(?<=[^\n])\n(?=[ \t]*(?:[*+-]\s|\d+\.\s))", "\n\n", result)

        return result
resolve Markdown extensions. Uses `pymdownx` extensions over native extensions if installed.""" return self._resolve_extensions @property def make_links(self) -> bool: """Whether to detect raw URLs in HTML text (`p` and `span` elements only) and convert them to links.""" return self._make_links @property def markdown_kwargs(self) -> tp.Kwargs: """Keyword arguments passed to `markdown.markdown`.""" return self._markdown_kwargs def to_html(self, markdown: str) -> str: """Convert Markdown to HTML.""" from vectorbtpro.utils.module_ import assert_can_import assert_can_import("markdown") import markdown as md markdown_kwargs = dict(self.markdown_kwargs) extensions = markdown_kwargs.pop("extensions", []) if self.resolve_extensions: from vectorbtpro.utils.module_ import check_installed filtered_extensions = [ ext for ext in extensions if "." not in ext or check_installed(ext.partition(".")[0]) ] ext_set = set(filtered_extensions) remove_fenced_code = "fenced_code" in ext_set and "pymdownx.superfences" in ext_set remove_codehilite = "codehilite" in ext_set and "pymdownx.highlight" in ext_set if remove_fenced_code or remove_codehilite: filtered_extensions = [ ext for ext in filtered_extensions if not ( (ext == "fenced_code" and remove_fenced_code) or (ext == "codehilite" and remove_codehilite) ) ] extensions = filtered_extensions html = md.markdown(markdown, extensions=extensions, **markdown_kwargs) if self.make_links: tag_pattern = re.compile(r"<(p|span)(\s[^>]*)?>(.*?)", re.DOTALL | re.IGNORECASE) url_pattern = re.compile(r'(https?://[^\s<>"\'`]+?)(?=[.,;:!?)\]]*(?:\s|$))', re.IGNORECASE) def _replace_urls(match, _url_pattern=url_pattern): tag = match.group(1) attributes = match.group(2) if match.group(2) else "" content = match.group(3) parts = re.split(r"(]*>.*?)", content, flags=re.DOTALL | re.IGNORECASE) for i, part in enumerate(parts): if not re.match(r"]*>.*?", part, re.DOTALL | re.IGNORECASE): part = _url_pattern.sub(r'\1', part) parts[i] = part new_content = 
"".join(parts) return f"<{tag}{attributes}>{new_content}" html = tag_pattern.sub(_replace_urls, html) return html.strip() def to_html(text: str, **kwargs) -> str: """Convert Markdown to HTML using `ToHTML`.""" return ToHTML(**kwargs).to_html(text) class FormatHTML(Configured): """Class to format HTML. If `use_pygments` is True, uses Pygments package for code highlighting. Arguments in `pygments_kwargs` are then passed to `pygments.formatters.HtmlFormatter`. Use `style_extras` to inject additional CSS rules outside the predefined ones. Use `head_extras` to inject additional HTML elements into the `` section, such as meta tags, links to external stylesheets, or scripts. Use `body_extras` to inject JavaScript files or inline scripts at the end of the ``. All of these arguments can be lists. HTML template is a template that can use all the arguments except those related to pygments. It can be either a custom template, or string or function that will become one.""" _settings_path: tp.SettingsPath = ["knowledge", "knowledge.formatting"] def __init__( self, html_template: tp.Optional[str] = None, style_extras: tp.Optional[tp.MaybeList[str]] = None, head_extras: tp.Optional[tp.MaybeList[str]] = None, body_extras: tp.Optional[tp.MaybeList[str]] = None, invert_colors: tp.Optional[bool] = None, invert_colors_style: tp.Optional[str] = None, auto_scroll: tp.Optional[bool] = None, auto_scroll_body: tp.Optional[str] = None, show_spinner: tp.Optional[bool] = None, spinner_style: tp.Optional[str] = None, spinner_body: tp.Optional[str] = None, use_pygments: tp.Optional[bool] = None, pygments_kwargs: tp.KwargsLike = None, template_context: tp.KwargsLike = None, **kwargs, ) -> None: from vectorbtpro.utils.module_ import check_installed, assert_can_import Configured.__init__( self, html_template=html_template, style_extras=style_extras, head_extras=head_extras, body_extras=body_extras, invert_colors=invert_colors, invert_colors_style=invert_colors_style, auto_scroll=auto_scroll, 
auto_scroll_body=auto_scroll_body, show_spinner=show_spinner, spinner_style=spinner_style, spinner_body=spinner_body, use_pygments=use_pygments, pygments_kwargs=pygments_kwargs, template_context=template_context, **kwargs, ) html_template = self.resolve_setting(html_template, "html_template") invert_colors = self.resolve_setting(invert_colors, "invert_colors") invert_colors_style = self.resolve_setting(invert_colors_style, "invert_colors_style") auto_scroll = self.resolve_setting(auto_scroll, "auto_scroll") auto_scroll_body = self.resolve_setting(auto_scroll_body, "auto_scroll_body") show_spinner = self.resolve_setting(show_spinner, "show_spinner") spinner_style = self.resolve_setting(spinner_style, "spinner_style") spinner_body = self.resolve_setting(spinner_body, "spinner_body") use_pygments = self.resolve_setting(use_pygments, "use_pygments") pygments_kwargs = self.resolve_setting(pygments_kwargs, "pygments_kwargs", merge=True) template_context = self.resolve_setting(template_context, "template_context", merge=True) def _prepare_extras(extras): if extras is None: extras = [] if isinstance(extras, str): extras = [extras] if not isinstance(extras, list): extras = list(extras) return "\n".join(extras) if isinstance(html_template, str): html_template = SafeSub(html_template) elif checks.is_function(html_template): html_template = RepFunc(html_template) elif not isinstance(html_template, CustomTemplate): raise TypeError(f"HTML template must be a string, function, or template") style_extras = _prepare_extras(self.get_setting("style_extras")) + _prepare_extras(style_extras) head_extras = _prepare_extras(self.get_setting("head_extras")) + _prepare_extras(head_extras) body_extras = _prepare_extras(self.get_setting("body_extras")) + _prepare_extras(body_extras) if invert_colors: style_extras = "\n".join([style_extras, invert_colors_style]) if auto_scroll: body_extras = "\n".join([body_extras, auto_scroll_body]) if show_spinner: style_extras = "\n".join([style_extras, 
spinner_style]) body_extras = "\n".join([body_extras, spinner_body]) if use_pygments is None: use_pygments = check_installed("pygments") if use_pygments: assert_can_import("pygments") from pygments.formatters import HtmlFormatter formatter = HtmlFormatter(**pygments_kwargs) highlight_css = formatter.get_style_defs(".highlight") if style_extras == "": style_extras = highlight_css else: style_extras = highlight_css + "\n" + style_extras self._html_template = html_template self._style_extras = style_extras self._head_extras = head_extras self._body_extras = body_extras self._template_context = template_context @property def html_template(self) -> CustomTemplate: """HTML template.""" return self._html_template @property def style_extras(self) -> str: """Extras for ` $head_extras $html_metadata $html_content $body_extras """, root_style_extras=[], style_extras=[], head_extras=[], body_extras=[ r"""""", r"""""", r"""""", r"""""", r"""""", ], invert_colors=False, invert_colors_style=""":root { filter: invert(100%); }""", auto_scroll=False, auto_scroll_body="""""", show_spinner=False, spinner_style=""".loader { width: 300px; height: 5px; margin: 0 auto; display: block; position: relative; overflow: hidden; } .loader::after { content: ''; width: 300px; height: 5px; background: blue; position: absolute; top: 0; left: 0; box-sizing: border-box; animation: animloader 1s ease-in-out infinite; } @keyframes animloader { 0%, 5% { left: 0; transform: translateX(-100%); } 95%, 100% { left: 100%; transform: translateX(0%); } } """, spinner_body="""""", output_to=None, flush_output=True, buffer_output=True, close_output=None, update_interval=None, minimal_format=False, formatter="ipython_auto", formatter_config=flex_cfg(), formatter_configs=flex_cfg( plain=flex_cfg(), ipython=flex_cfg(), ipython_markdown=flex_cfg(), ipython_html=flex_cfg(), html=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'html'"), mkdir_kwargs=flex_cfg(), temp_files=False, refresh_page=True, file_prefix_len=20, 
file_suffix_len=6, auto_scroll=True, show_spinner=True, ), ), ), chat=flex_cfg( chat_dir=RepEval("Path(cache_dir) / 'chat'"), stream=True, to_context_kwargs=flex_cfg(), incl_past_queries=True, rank=None, rank_kwargs=flex_cfg( top_k=None, min_top_k=None, max_top_k=None, cutoff=None, return_chunks=False, ), max_tokens=120_000, system_prompt=r"You are a helpful assistant. Given the context information and not prior knowledge, answer the query.", system_as_user=True, context_prompt=r"""Context information is below. --------------------- $context ---------------------""", minimal_format=True, tokenizer="tiktoken", tokenizer_config=flex_cfg(), tokenizer_configs=flex_cfg( tiktoken=flex_cfg( encoding="model_or_o200k_base", model=None, tokens_per_message=3, tokens_per_name=1, ), ), embeddings="auto", embeddings_config=flex_cfg( batch_size=512, ), embeddings_configs=flex_cfg( openai=flex_cfg( model="text-embedding-3-large", dimensions=256, ), litellm=flex_cfg( model="text-embedding-3-large", dimensions=256, ), llama_index=flex_cfg( embedding="openai", embedding_configs=flex_cfg( openai=flex_cfg( model="text-embedding-3-large", dimensions=256, ) ), ), ), completions="auto", completions_config=flex_cfg(), completions_configs=flex_cfg( openai=flex_cfg( model="gpt-4o", ), litellm=flex_cfg( model="gpt-4o", ), llama_index=flex_cfg( llm="openai", llm_configs=flex_cfg( openai=flex_cfg( model="gpt-4o", ) ), ), ), text_splitter="segment", text_splitter_config=flex_cfg( chunk_template=r"""... 
(previous text omitted) $chunk_text""", ), text_splitter_configs=flex_cfg( token=flex_cfg( chunk_size=800, chunk_overlap=400, tokenizer="tiktoken", tokenizer_kwargs=flex_cfg( encoding="cl100k_base", ), ), segment=flex_cfg( separators=[[r"\n\s*\n", r"(?<=[^\s.?!])[.?!]+(?:\s+|$)"], r"\s+", None], min_chunk_size=0.8, fixed_overlap=False, ), llama_index=flex_cfg( node_parser="sentence", node_parser_configs=flex_cfg(), ), ), obj_store="memory", obj_store_config=flex_cfg( store_id="default", purge_on_open=False, ), obj_store_configs=flex_cfg( memory=flex_cfg(), file=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'file_store'"), compression=None, save_kwargs=flex_cfg( mkdir_kwargs=flex_cfg(), ), load_kwargs=flex_cfg(), use_patching=True, consolidate=False, mirror=True, ), lmdb=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'lmdb_store'"), mkdir_kwargs=flex_cfg(), dumps_kwargs=flex_cfg(), loads_kwargs=flex_cfg(), mirror=True, flag="c", ), cached=flex_cfg( lazy_open=True, mirror=False, ), ), doc_ranker_config=flex_cfg( dataset_id=None, cache_doc_store=True, cache_emb_store=True, doc_store_configs=flex_cfg( memory=flex_cfg( store_id="doc_default", ), file=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'doc_file_store'"), ), lmdb=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'doc_lmdb_store'"), ), ), emb_store_configs=flex_cfg( memory=flex_cfg( store_id="emb_default", ), file=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'emb_file_store'"), ), lmdb=flex_cfg( dir_path=RepEval("Path(cache_dir) / 'emb_lmdb_store'"), ), ), score_func="cosine", score_agg_func="mean", ), ), assets=flex_cfg( vbt=flex_cfg( cache_dir="./knowledge/vbt/", release_dir=RepEval("(Path(cache_dir) / release_name) if release_name else cache_dir"), assets_dir=RepEval("Path(release_dir) / 'assets'"), markdown_dir=RepEval("Path(release_dir) / 'markdown'"), html_dir=RepEval("Path(release_dir) / 'html'"), release_name=None, asset_name=None, repo_owner="polakowo", repo_name="vectorbt.pro", token=None, 
token_required=False, use_pygithub=None, chunk_size=8192, document_cls=None, document_kwargs=flex_cfg( text_path="content", excl_metadata=RepEval("asset_cls.get_setting('minimize_keys')"), excl_embed_metadata=True, split_text_kwargs=flex_cfg(), ), minimize_metadata=False, minimize_keys=[ "parent", "children", "type", "icon", "tags", "block", "thread", "replies", "mentions", "reactions", ], minimize_links=False, minimize_link_rules=flex_cfg( { r"(https://vectorbt\.pro/pvt_[a-zA-Z0-9]+)": "$pvt_site", r"(https://vectorbt\.pro)": "$pub_site", r"(https://discord\.com/channels/[0-9]+)": "$discord", r"(https://github\.com/polakowo/vectorbt\.pro)": "$github", } ), root_metadata_key=None, aggregate_fields=False, parent_links_only=True, clean_metadata=True, clean_metadata_kwargs=flex_cfg(), dump_metadata_kwargs=flex_cfg(), incl_base_attr=True, incl_shortcuts=True, incl_shortcut_access=True, incl_shortcut_call=True, incl_instances=True, incl_custom=None, is_custom_regex=False, as_code=False, as_regex=True, allow_prefix=False, allow_suffix=False, merge_targets=True, display=flex_cfg( html_template=r""" $title $head_extras $body_extras """, style_extras=[], head_extras=[], body_extras=[], ), chat=flex_cfg( chat_dir=RepEval("Path(release_dir) / 'chat'"), system_prompt=r"""You are a helpful assistant with access to VectorBT PRO (also called VBT or vectorbtpro) documentation and relevant Discord history. Use only this provided context to generate clear, accurate answers. Do not reference the open‑source vectorbt, as VectorBT PRO is a proprietary successor with significant differences.\n\nWhen coding in Python, use:\n```python\nimport vectorbtpro as vbt\n```\n\nIf metadata includes links, reference them to support your answer. Do not include external or fabricated links, and exclude any information not present in the given context.\n\nFor each query, follow this structure:\n1. Optionally restate the question in your own words.\n2. Answer using only the available context.\n3. 
Include any relevant links.""", doc_ranker_config=flex_cfg( doc_store="lmdb", doc_store_configs=flex_cfg( file=flex_cfg( dir_path=RepEval("Path(release_dir) / 'doc_file_store'"), ), lmdb=flex_cfg( dir_path=RepEval("Path(release_dir) / 'doc_lmdb_store'"), ), ), emb_store="lmdb", emb_store_configs=flex_cfg( file=flex_cfg( dir_path=RepEval("Path(release_dir) / 'emb_file_store'"), ), lmdb=flex_cfg( dir_path=RepEval("Path(release_dir) / 'emb_lmdb_store'"), ), ), ), ), ), pages=flex_cfg( assets_dir=RepEval("Path(release_dir) / 'pages' / 'assets'"), markdown_dir=RepEval("Path(release_dir) / 'pages' / 'markdown'"), html_dir=RepEval("Path(release_dir) / 'pages' / 'html'"), asset_name="pages.json.zip", token_required=True, append_obj_type=True, append_github_link=True, use_parent=None, use_base_parents=False, use_ref_parents=False, incl_bases=True, incl_ancestors=True, incl_base_ancestors=False, incl_refs=None, incl_descendants=True, incl_ancestor_descendants=False, incl_ref_descendants=False, aggregate=True, aggregate_ancestors=False, aggregate_refs=False, topo_sort=True, incl_pages=None, excl_pages=None, page_find_mode="substring", up_aggregate=True, up_aggregate_th=2 / 3, up_aggregate_pages=True, ), messages=flex_cfg( assets_dir=RepEval("Path(release_dir) / 'messages' / 'assets'"), markdown_dir=RepEval("Path(release_dir) / 'messages' / 'markdown'"), html_dir=RepEval("Path(release_dir) / 'messages' / 'html'"), asset_name="messages.json.zip", token_required=True, ), ), ) """_""" __pdoc__["knowledge"] = Sub( """Sub-config with settings applied across `vectorbtpro.utils.knowledge`. 
```python ${config_doc} ```""" ) _settings["knowledge"] = knowledge # ############# Settings config ############# # class SettingsConfig(Config): """Extends `vectorbtpro.utils.config.Config` for global settings.""" def __init__( self, *args, **kwargs, ) -> None: options_ = kwargs.pop("options_", None) if options_ is None: options_ = {} copy_kwargs = options_.pop("copy_kwargs", None) if copy_kwargs is None: copy_kwargs = {} copy_kwargs["copy_mode"] = "deep" options_["copy_kwargs"] = copy_kwargs options_["frozen_keys"] = True options_["as_attrs"] = True Config.__init__(self, *args, options_=options_, **kwargs) def register_template(self, theme: str) -> None: """Register template of a theme.""" if check_installed("plotly"): import plotly.io as pio import plotly.graph_objects as go template_path = self["plotting"]["themes"][theme]["path"] if template_path is None: raise ValueError(f"Must provide template path for the theme '{theme}'") if template_path.startswith("__name__/"): template_path = template_path.replace("__name__/", "") template = Config(json.loads(pkgutil.get_data(__name__, template_path))) else: with open(template_path, "r") as f: template = Config(json.load(f)) pio.templates["vbt_" + theme] = go.layout.Template(template) def register_templates(self) -> None: """Register templates of all themes.""" for theme in self["plotting"]["themes"]: self.register_template(theme) def set_theme(self, theme: str) -> None: """Set default theme.""" self.register_template(theme) self["plotting"]["color_schema"].update(self["plotting"]["themes"][theme]["color_schema"]) self["plotting"]["layout"]["template"] = "vbt_" + theme def reset_theme(self) -> None: """Reset to default theme.""" self.set_theme(self["plotting"]["default_theme"]) def substitute_sub_config_docs(self, __pdoc__: dict, prettify_kwargs: tp.KwargsLike = None) -> None: """Substitute templates in sub-config docs.""" if prettify_kwargs is None: prettify_kwargs = {} for k, v in __pdoc__.items(): if k in self: 
config_doc = self[k].prettify(**prettify_kwargs.get(k, {})) __pdoc__[k] = substitute_templates( v, context=dict(config_doc=config_doc), eval_id="__pdoc__", ) def get(self, key: tp.PathLikeKey, default: tp.Any = MISSING) -> tp.Any: """Get setting(s) under a path. See `vectorbtpro.utils.search_.get_pathlike_key` for path format.""" from vectorbtpro.utils.search_ import get_pathlike_key try: return get_pathlike_key(self, key) except (KeyError, IndexError, AttributeError) as e: if default is MISSING: raise e return default def set(self, key: tp.PathLikeKey, value: tp.Any, default_config_type: tp.Type[Config] = flex_cfg) -> None: """Set setting(s) under a path. See `vectorbtpro.utils.search_.get_pathlike_key` for path format.""" from vectorbtpro.utils.search_ import resolve_pathlike_key tokens = resolve_pathlike_key(key) obj = self for i, token in enumerate(tokens): if isinstance(obj, Config): if token not in obj: obj[token] = default_config_type() if i < len(tokens) - 1: if isinstance(obj, (set, frozenset)): obj = list(obj)[token] elif hasattr(obj, "__getitem__"): obj = obj[token] elif isinstance(token, str) and hasattr(obj, token): obj = getattr(obj, token) else: raise TypeError(f"Cannot navigate object of type {type(obj).__name__}") else: if hasattr(obj, "__setitem__"): obj[token] = value elif hasattr(obj, "__dict__"): setattr(obj, token, value) else: raise TypeError(f"Cannot modify object of type {type(obj).__name__}") settings = SettingsConfig(_settings) """Global settings config. 
Combines all sub-configs defined in this module."""

# Optionally update the defaults from a settings file.
# NOTE(review): indentation was lost in this dump; the `elif` is assumed to pair with the
# outer `if`, so that an empty VBT_SETTINGS_PATH skips loading altogether — confirm.
settings_name = os.environ.get("VBT_SETTINGS_NAME", "vbt")
if "VBT_SETTINGS_PATH" in os.environ:
    if len(os.environ["VBT_SETTINGS_PATH"]) > 0:
        settings.load_update(os.environ["VBT_SETTINGS_PATH"])
elif settings.file_exists(settings_name):
    settings.load_update(settings_name)

# Apply the default theme and register the templates of all themes
settings.reset_theme()
settings.register_templates()
settings.make_checkpoint()
# Fill `${config_doc}` in the sub-config docs collected in `__pdoc__`
settings.substitute_sub_config_docs(__pdoc__)

if settings["numba"]["disable"]:
    nb_config.DISABLE_JIT = True

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================

"""General types used across vectorbtpro."""

from datetime import datetime, timedelta, tzinfo, date, time
from enum import EnumMeta
from pathlib import Path
import sys

if sys.version_info < (3, 9):
    import typing

    # Add "TextIO" to `typing.__all__` so that the star import below picks it up
    typing.__all__.append("TextIO")
from typing import *

import numpy as np
import pandas as pd
from mypy_extensions import VarArg
from pandas import Series, DataFrame as Frame, Index
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.indexing import _IndexSlice as IndexSlice
from pandas.core.resample import Resampler as PandasResampler
from pandas.tseries.offsets import BaseOffset

# Plotly types are imported for type checkers only; at runtime the ImportError
# raised below routes to the `Any` fallbacks
try:
    if not TYPE_CHECKING:
        raise ImportError
    from plotly.graph_objects import Figure, FigureWidget
    from plotly.basedatatypes import BaseFigure, BaseTraceType
except ImportError:
    Figure = Any
    FigureWidget = Any
    BaseFigure = Any
    BaseTraceType = Any
# Fall back to `typing_extensions` if `typing` doesn't provide the name
try:
    from typing import Protocol
except ImportError:
    from typing_extensions import Protocol
try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

# Project types are imported for type checkers only; at runtime each name is bound to
# a string with the same spelling
if TYPE_CHECKING:
    from vectorbtpro.utils.parsing import Regex
    from vectorbtpro.utils.execution import Task, ExecutionEngine
    from vectorbtpro.utils.chunking import Sizer, NotChunked, ChunkTaker, ChunkMeta, ChunkMetaGenerator
    from vectorbtpro.utils.jitting import Jitter
    from vectorbtpro.utils.template import CustomTemplate
    from vectorbtpro.utils.datetime_ import DTC, DTCNT
    from vectorbtpro.utils.selection import PosSel, LabelSel
    from vectorbtpro.utils.merging import MergeFunc
    from vectorbtpro.utils.knowledge.base_asset_funcs import AssetFunc
    from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline
    from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
    from vectorbtpro.utils.knowledge.chatting import (
        Tokenizer,
        Embeddings,
        Completions,
        TextSplitter,
        StoreDocument,
        ObjectStore,
        EmbeddedDocument,
        ScoredDocument,
    )
    from vectorbtpro.utils.knowledge.custom_assets import VBTAsset, PagesAsset, MessagesAsset
    from vectorbtpro.utils.knowledge.formatting import ContentFormatter
    from vectorbtpro.base.indexing import hslice
    from vectorbtpro.base.grouping.base import Grouper
    from vectorbtpro.base.resampling.base import Resampler
    from vectorbtpro.generic.splitting.base import FixRange, RelRange
else:
    Regex = "Regex"
    Task = "Task"
    ExecutionEngine = "ExecutionEngine"
    Sizer = "Sizer"
    NotChunked = "NotChunked"
    ChunkTaker = "ChunkTaker"
    ChunkMeta = "ChunkMeta"
    ChunkMetaGenerator = "ChunkMetaGenerator"
    # NOTE(review): `TraceUpdater` has no matching import in the TYPE_CHECKING branch — confirm
    TraceUpdater = "TraceUpdater"
    Jitter = "Jitter"
    CustomTemplate = "CustomTemplate"
    DTC = "DTC"
    DTCNT = "DTCNT"
    PosSel = "PosSel"
    LabelSel = "LabelSel"
    MergeFunc = "MergeFunc"
    AssetFunc = "AssetFunc"
    AssetPipeline = "AssetPipeline"
    KnowledgeAsset = "KnowledgeAsset"
    Tokenizer = "Tokenizer"
    Embeddings = "Embeddings"
    Completions = "Completions"
    TextSplitter = "TextSplitter"
    StoreDocument = "StoreDocument"
    ObjectStore = "ObjectStore"
    EmbeddedDocument = "EmbeddedDocument"
    ScoredDocument = "ScoredDocument"
    VBTAsset = "VBTAsset"
    PagesAsset = "PagesAsset"
    MessagesAsset = "MessagesAsset"
    ContentFormatter = "ContentFormatter"
    hslice = "hslice"
    Grouper = "Grouper"
    Resampler = "Resampler"
    FixRange = "FixRange"
    RelRange = "RelRange"

# Nothing is exported by a star import of this module
__all__ = []

# Generic types
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
MaybeType = Union[T, Type[T]]  # an instance or its class

# Scalars
Scalar = Union[str, float, int, complex, bool, object, np.generic]
Number = Union[int, float, complex, np.number, np.bool_]
Int = Union[int, np.integer]
Float = Union[float, np.floating]
IntFloat = Union[Int, Float]
IntStr = Union[Int, str]

# Basic sequences
# The `Maybe*` aliases accept either a single `T` or the named container of `T`
MaybeTuple = Union[T, Tuple[T, ...]]
MaybeList = Union[T, List[T]]
MaybeSet = Union[T, Set[T]]
TupleList = Union[List[T], Tuple[T, ...]]
MaybeTupleList = Union[T, List[T], Tuple[T, ...]]
MaybeIterable = Union[T, Iterable[T]]
MaybeSequence = Union[T, Sequence[T]]
MaybeDict = Union[Dict[Hashable, T], T]
MappingSequence = Union[Mapping[Hashable, T], Sequence[T]]
MaybeMappingSequence = Union[T, Mapping[Hashable, T], Sequence[T]]
SetLike = Union[None, Set[T]]
Items = Iterator[Tuple[Hashable, Any]]

# Arrays
class SupportsArrayT(Protocol):
    """Protocol for objects that define `__array__` returning a NumPy array."""

    def __array__(self) -> np.ndarray: ...


DTypeLike = Any
PandasDTypeLike = Any
TypeLike = MaybeIterable[Union[Type, str, Regex]]
Shape = Tuple[int, ...]
ShapeLike = Union[int, Shape]
# All array aliases below resolve to the same `np.ndarray` type at runtime;
# the suffix only documents the intended use
Array = np.ndarray  # ready to be used for n-dim data
Array1d = np.ndarray
Array2d = np.ndarray
Array3d = np.ndarray
Record = np.void
RecordArray = np.ndarray
RecordArray2d = np.ndarray
RecArray = np.recarray
MaybeArray = Union[Scalar, Array]
MaybeIndexArray = Union[int, slice, Array1d, Array2d]
SeriesFrame = Union[Series, Frame]
MaybeSeries = Union[Scalar, Series]
MaybeSeriesFrame = Union[T, Series, Frame]
PandasArray = Union[Index, Series, Frame]
AnyArray = Union[Array, PandasArray]
AnyArray1d = Union[Array1d, Index, Series]
AnyArray2d = Union[Array2d, Frame]
ArrayLike = Union[Scalar, Sequence[Scalar], Sequence[Sequence[Any]], SupportsArrayT, Array]
IndexLike = Union[range, Sequence[Scalar], SupportsArrayT]
FlexArray1d = Array1d
FlexArray2d = Array2d
FlexArray1dLike = Union[Scalar, Array1d, Array2d]
FlexArray2dLike = Union[Scalar, Array1d, Array2d]

# Templates
CustomTemplateLike = Union[str, Callable, CustomTemplate]

# Labels
Label = Hashable
Labels = Sequence[Label]
Level = Union[str, int]
LevelSequence = Sequence[Level]
MaybeLevelSequence = Union[Level, LevelSequence]

# Datetime
# The `*Like` aliases extend the concrete union with str, int, and float
Datetime = Union[pd.Timestamp, np.datetime64, datetime]
DatetimeLike = Union[str, int, float, Datetime]
Timedelta = Union[pd.Timedelta, np.timedelta64, timedelta]
TimedeltaLike = Union[str, int, float, Timedelta]
Frequency = Union[BaseOffset, Timedelta]
FrequencyLike = Union[BaseOffset, TimedeltaLike]
TimezoneLike = Union[None, str, int, float, timedelta, tzinfo]
TimeLike = Union[str, time]
PandasFrequency = Union[BaseOffset, pd.Timedelta]
PandasDatetimeIndex = Union[pd.DatetimeIndex, pd.PeriodIndex]
AnyPandasFrequency = Union[None, int, float, PandasFrequency]
DTCLike = Union[None, str, int, time, date, Datetime, DTC, DTCNT]

# Indexing
Slice = Union[slice, hslice]
PandasIndexingFunc = Callable[[SeriesFrame], MaybeSeriesFrame]

# Grouping
PandasGroupByLike = Union[PandasGroupBy, PandasResampler, FrequencyLike]
GroupByLike = Union[None, bool, MaybeLevelSequence, IndexLike, CustomTemplate]
AnyGroupByLike = Union[Grouper, PandasGroupByLike, GroupByLike]
AnyRuleLike = Union[Resampler, PandasResampler, FrequencyLike, IndexLike]
GroupIdxs = Array1d
GroupLens = Array1d
GroupMap = Tuple[GroupIdxs, GroupLens]

# Wrapping
NameIndex = Union[None, Any, Index]

# Search
PathKeyToken = Hashable
PathKeyTokens = Sequence[Hashable]
PathKey = Tuple[PathKeyToken, ...]
MaybePathKey = Union[None, PathKeyToken, PathKey]
PathLikeKey = Union[MaybePathKey, Path]
PathLikeKeys = Sequence[PathLikeKey]
PathMoveDict = Dict[PathLikeKey, PathLikeKey]
PathRenameDict = Dict[PathLikeKey, PathKeyToken]
PathDict = Dict[PathLikeKey, Any]

# Config
DictLike = Union[None, dict]
DictLikeSequence = MaybeSequence[DictLike]
Args = Tuple[Any, ...]
ArgsLike = Union[None, Args]
Kwargs = Dict[str, Any]
KwargsLike = Union[None, Kwargs]
KwargsLikeSequence = MaybeSequence[KwargsLike]
ArgsKwargs = Tuple[Args, Kwargs]
PathLike = Union[str, Path]
# `_SettingsPath` is the plain union; `SettingsPath` wraps it in `ClassVar`
_SettingsPath = Union[None, MaybeList[PathLikeKey], Dict[Hashable, PathLikeKey]]
SettingsPath = ClassVar[_SettingsPath]
ExtSettingsPaths = List[Tuple[type, _SettingsPath]]
SpecSettingsPaths = Dict[PathLikeKey, MaybeList[PathLikeKey]]
WriteableAttrs = ClassVar[Optional[Set[str]]]
ExpectedKeysMode = ClassVar[str]
ExpectedKeys = ClassVar[Optional[Set[str]]]

# Data
# Each chained assignment binds several names to one and the same type
Column = Key = Feature = Symbol = Hashable
Columns = Keys = Features = Symbols = Sequence[Hashable]
MaybeColumns = MaybeKeys = MaybeFeatures = MaybeSymbols = Union[Hashable, Sequence[Hashable]]
KeyData = FeatureData = SymbolData = Union[None, SeriesFrame, Tuple[SeriesFrame, Kwargs]]

# Plotting
TraceName = Union[str, None]
TraceNames = MaybeSequence[TraceName]

# Generic
# `VarArg()` marks the variable positional arguments that follow the fixed ones
MapFunc = Callable[[Scalar, VarArg()], Scalar]
MapMetaFunc = Callable[[int, int, Scalar, VarArg()], Scalar]
ApplyFunc = Callable[[Array1d, VarArg()], MaybeArray]
ApplyMetaFunc = Callable[[int, VarArg()], MaybeArray]
ReduceFunc = Callable[[Array1d, VarArg()], Scalar]
ReduceMetaFunc = Callable[[int, VarArg()], Scalar]
ReduceToArrayFunc = Callable[[Array1d, VarArg()], Array1d]
ReduceToArrayMetaFunc = Callable[[int, VarArg()], Array1d]
ReduceGroupedFunc = Callable[[Array2d, VarArg()], Scalar]
ReduceGroupedMetaFunc = Callable[[GroupIdxs, int, VarArg()], Scalar]
ReduceGroupedToArrayFunc = Callable[[Array2d, VarArg()], Array1d]
ReduceGroupedToArrayMetaFunc = Callable[[GroupIdxs, int, VarArg()], Array1d]
RangeReduceMetaFunc = Callable[[int, int, int, VarArg()], Scalar]
ProximityReduceMetaFunc = Callable[[int, int, int, int, VarArg()], Scalar]
GroupByReduceMetaFunc = Callable[[GroupIdxs, int, int, VarArg()], Scalar]
GroupSqueezeMetaFunc = Callable[[int, GroupIdxs, int, VarArg()], Scalar]
GroupByTransformFunc = Callable[[Array2d, VarArg()], MaybeArray]
GroupByTransformMetaFunc = Callable[[GroupIdxs, int, VarArg()], MaybeArray]

# Signals
PlaceFunc = Callable[[NamedTuple, VarArg()], int]
RankFunc = Callable[[NamedTuple, VarArg()], int]

# Records
RecordsMapFunc = Callable[[np.void, VarArg()], Scalar]
RecordsMapMetaFunc = Callable[[int, VarArg()], Scalar]
MappedReduceMetaFunc = Callable[[GroupIdxs, int, VarArg()], Scalar]
MappedReduceToArrayMetaFunc = Callable[[GroupIdxs, int, VarArg()], Array1d]

# Indicators
ParamValue = Any
ParamValues = Sequence[ParamValue]
MaybeParamValues = MaybeSequence[ParamValue]
MaybeParams = Sequence[MaybeParamValues]
Params = Sequence[ParamValues]
ParamsOrLens = Sequence[Union[ParamValues, int]]
ParamsOrDict = Union[Params, Dict[Hashable, ParamValues]]
ParamGrid = Union[ParamsOrLens, Dict[Hashable, ParamsOrLens]]
ParamComb = Sequence[ParamValue]
ParamCombOrDict = Union[ParamComb, Dict[Hashable, ParamValue]]

# Mappings
MappingLike = Union[str, Mapping, NamedTuple, EnumMeta, IndexLike]
RecordsLike = Union[SeriesFrame, RecordArray, Sequence[MappingLike]]

# Annotations
Annotation = object
Annotations = Dict[str, Annotation]

# Parsing
# `AnnArgs` and `FlatAnnArgs` are currently the same type
AnnArgs = Dict[str, Kwargs]
FlatAnnArgs = Dict[str, Kwargs]
AnnArgQuery = Union[int, str, Regex]

#
# Execution
FuncArgs = Tuple[Callable, Args, Kwargs]  # function, positional arguments, keyword arguments
FuncsArgs = Iterable[FuncArgs]
TaskLike = Union[FuncArgs, Task]
TasksLike = Iterable[TaskLike]
ExecutionEngineLike = Union[str, type, ExecutionEngine, Callable]
ExecResult = Any
ExecResults = List[Any]

# JIT
JittedOption = Union[None, bool, str, Callable, Kwargs]
JitterLike = Union[str, Jitter, Type[Jitter]]
TaskId = Union[Hashable, Callable]

# Merging
MergeFuncLike = MaybeSequence[Union[None, str, Callable, MergeFunc]]
MergeResult = Any
MergeableResults = Union[ExecResults, MergeResult]

# Chunking
SizeFunc = Callable[[AnnArgs], int]
SizeLike = Union[int, str, Sizer, SizeFunc]
ChunkMetaFunc = Callable[[AnnArgs], Iterable[ChunkMeta]]
ChunkMetaLike = Union[Iterable[ChunkMeta], ChunkMetaGenerator, ChunkMetaFunc]
# A take specification is either a class or an instance of `NotChunked`/`ChunkTaker`, or None
TakeSpec = Union[None, Type[NotChunked], Type[ChunkTaker], NotChunked, ChunkTaker]
ArgTakeSpec = Mapping[AnnArgQuery, TakeSpec]
ArgTakeSpecFunc = Callable[[AnnArgs, ChunkMeta], Tuple[Args, Kwargs]]
ArgTakeSpecLike = Union[Sequence[TakeSpec], ArgTakeSpec, ArgTakeSpecFunc, CustomTemplate]
MappingTakeSpec = Mapping[Hashable, TakeSpec]
SequenceTakeSpec = Sequence[TakeSpec]
ContainerTakeSpec = Union[MappingTakeSpec, SequenceTakeSpec]
ChunkedOption = Union[None, bool, str, Callable, Kwargs]

# Decorators
ClassWrapper = Callable[[Type[T]], Type[T]]
# Either a class decorator or the class itself
FlexClassWrapper = Union[Callable[[Type[T]], Type[T]], Type[T]]

# Splitting
FixRangeLike = Union[Slice, Sequence[int], Sequence[bool], Callable, CustomTemplate, FixRange]
RelRangeLike = Union[int, float, Callable, CustomTemplate, RelRange]
RangeLike = Union[FixRangeLike, RelRangeLike]
ReadyRangeLike = Union[slice, Array1d]
FixSplit = Sequence[FixRangeLike]
SplitLike = Union[str, int, float, MaybeSequence[RangeLike]]
Splits = Sequence[SplitLike]
SplitsArray = Array2d
SplitsMask = Array3d
BoundsArray = Array3d

# Staticization
StaticizedOption = Union[None, bool, Kwargs, TaskId]

# Selection
Selection = Union[PosSel, LabelSel, MaybeIterable[Union[PosSel, LabelSel, Hashable]]]

#
# Knowledge
AssetFuncLike = Union[str, Type[AssetFunc], FuncArgs, Task, Callable]
# Generic over the asset type `T`; specialized for each asset class below
MaybeAsset = Union[None, T, dict, list, Iterator[T]]
MaybeKnowledgeAsset = MaybeAsset[KnowledgeAsset]
MaybeVBTAsset = MaybeAsset[VBTAsset]
MaybePagesAsset = MaybeAsset[PagesAsset]
MaybeMessagesAsset = MaybeAsset[MessagesAsset]
# The `*Like` aliases below accept None, a string, or an instance/class of the component
ContentFormatterLike = Union[None, str, MaybeType[ContentFormatter]]
TokenizerLike = Union[None, str, MaybeType[Tokenizer]]
Token = int
Tokens = List[Token]
EmbeddingsLike = Union[None, str, MaybeType[Embeddings]]
CompletionsLike = Union[None, str, MaybeType[Completions]]
ChatMessage = dict
ChatMessages = List[ChatMessage]
ChatHistory = MutableSequence[ChatMessage]
ChatOutput = Union[Optional[Path], Tuple[Optional[Path], Any]]
MaybeChatOutput = Union[ChatOutput, Tuple[ChatOutput, Completions]]
TextSplitterLike = Union[None, str, MaybeType[TextSplitter]]
TSRange = Tuple[int, int]
TSRangeChunks = Iterator[TSRange]
TSSegment = Tuple[int, int, bool]
# Iterates over segments (3-tuples), mirroring the `TSRange`/`TSRangeChunks` pair above;
# this alias previously pointed at `Iterator[TSRange]`, leaving `TSSegment` unused
TSSegmentChunks = Iterator[TSSegment]
TSTextChunks = Iterator[str]
ObjectStoreLike = Union[None, str, MaybeType[ObjectStore]]
EmbeddedDocuments = List[EmbeddedDocument]
ScoredDocuments = List[Union[float, ScoredDocument]]
RankedDocuments = List[Union[StoreDocument, ScoredDocument]]
TopKLike = Union[None, int, float, str, Callable]

# Chaining
PipeFunc = Union[str, Callable, Tuple[Union[str, Callable], str]]
PipeTask = Union[PipeFunc, Tuple[PipeFunc, Args, Kwargs], Task]
PipeTasks = Iterable[PipeTask]

# Pickling
CompressionLike = Union[None, bool, str]

# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# =================================================================================== try: try: import tomllib except ModuleNotFoundError: import tomli as tomllib from pathlib import Path with open(Path(__file__).resolve().parent.parent / "pyproject.toml", "rb") as f: pyproject = tomllib.load(f) __version__ = pyproject["project"]["version"] except Exception as e: import importlib.metadata __version__ = importlib.metadata.version(__package__ or __name__) __release__ = "v" + __version__ __all__ = [ "__version__", "__release__", ] # ==================================== VBTPROXYZ ==================================== # Copyright (c) 2021-2025 Oleg Polakow. All rights reserved. # # This file is part of the proprietary VectorBT® PRO package and is licensed under # the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/ # # Unauthorized publishing, distribution, sublicensing, or sale of this software # or its parts is strictly prohibited. # =================================================================================== """Root Pandas accessors of vectorbtpro. An accessor adds additional "namespace" to pandas objects. The `vectorbtpro.accessors` registers a custom `vbt` accessor on top of each `pd.Index`, `pd.Series`, and `pd.DataFrame` object. 
It is the main entry point for all other accessors:

```plaintext
vbt.base.accessors.BaseIDX/SR/DFAccessor    -> pd.Index/Series/DataFrame.vbt.*
vbt.generic.accessors.GenericSR/DFAccessor  -> pd.Series/DataFrame.vbt.*
vbt.signals.accessors.SignalsSR/DFAccessor  -> pd.Series/DataFrame.vbt.signals.*
vbt.returns.accessors.ReturnsSR/DFAccessor  -> pd.Series/DataFrame.vbt.returns.*
vbt.ohlcv.accessors.OHLCVDFAccessor         -> pd.DataFrame.vbt.ohlcv.*
vbt.px.accessors.PXSR/DFAccessor            -> pd.Series/DataFrame.vbt.px.*
```

Additionally, some accessors subclass other accessors, building the following inheritance hierarchy:

```plaintext
vbt.base.accessors.BaseIDXAccessor
vbt.base.accessors.BaseSR/DFAccessor
    -> vbt.generic.accessors.GenericSR/DFAccessor
        -> vbt.signals.accessors.SignalsSR/DFAccessor
        -> vbt.returns.accessors.ReturnsSR/DFAccessor
        -> vbt.ohlcv.accessors.OHLCVDFAccessor
        -> vbt.px.accessors.PXSR/DFAccessor
```

So, for example, the method `pd.Series.vbt.to_2d_array` is also available as `pd.Series.vbt.returns.to_2d_array`.

Class methods of any accessor can be conveniently accessed using the `pd_acc`, `sr_acc`, and `df_acc` shortcuts:

```pycon
>>> from vectorbtpro import *

>>> vbt.pd_acc.signals.generate
<bound method SignalsAccessor.generate of <class 'vectorbtpro.signals.accessors.SignalsAccessor'>>
```

!!! note
    Accessors in vectorbt are not cached, so querying `df.vbt` twice will also call `Vbt_DFAccessor` twice.
You can change this in global settings.""" import pandas as pd from pandas.core.accessor import DirNamesMixin from vectorbtpro import _typing as tp from vectorbtpro.base.accessors import BaseIDXAccessor from vectorbtpro.base.wrapping import ArrayWrapper from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor from vectorbtpro.utils.base import Base from vectorbtpro.utils.warnings_ import warn __all__ = [ "Vbt_Accessor", "Vbt_SRAccessor", "Vbt_DFAccessor", "idx_acc", "pd_acc", "sr_acc", "df_acc", ] __pdoc__ = {} ParentAccessorT = tp.TypeVar("ParentAccessorT", bound=object) AccessorT = tp.TypeVar("AccessorT", bound=object) class Accessor(Base): """Accessor.""" def __init__(self, name: str, accessor: tp.Type[AccessorT]) -> None: self._name = name self._accessor = accessor def __get__(self, obj: ParentAccessorT, cls: DirNamesMixin) -> AccessorT: if obj is None: return self._accessor if isinstance(obj, (pd.Index, pd.Series, pd.DataFrame)): accessor_obj = self._accessor(obj) elif issubclass(self._accessor, type(obj)): accessor_obj = obj.replace(cls_=self._accessor) else: accessor_obj = self._accessor(obj.wrapper, obj=obj._obj) return accessor_obj class CachedAccessor(Base): """Cached accessor.""" def __init__(self, name: str, accessor: tp.Type[AccessorT]) -> None: self._name = name self._accessor = accessor def __get__(self, obj: ParentAccessorT, cls: DirNamesMixin) -> AccessorT: if obj is None: return self._accessor if isinstance(obj, (pd.Index, pd.Series, pd.DataFrame)): accessor_obj = self._accessor(obj) elif issubclass(self._accessor, type(obj)): accessor_obj = obj.replace(cls_=self._accessor) else: accessor_obj = self._accessor(obj.wrapper, obj=obj._obj) object.__setattr__(obj, self._name, accessor_obj) return accessor_obj def register_accessor(name: str, cls: tp.Type[DirNamesMixin]) -> tp.Callable: """Register a custom accessor. 
`cls` must subclass `pandas.core.accessor.DirNamesMixin`.""" def decorator(accessor: tp.Type[AccessorT]) -> tp.Type[AccessorT]: from vectorbtpro._settings import settings caching_cfg = settings["caching"] if hasattr(cls, name): warn( f"registration of accessor {repr(accessor)} under name " f"{repr(name)} for type {repr(cls)} is overriding a preexisting " "attribute with the same name." ) if caching_cfg["use_cached_accessors"]: setattr(cls, name, CachedAccessor(name, accessor)) else: setattr(cls, name, Accessor(name, accessor)) cls._accessors.add(name) return accessor return decorator def register_index_accessor(name: str) -> tp.Callable: """Decorator to register a custom `pd.Index` accessor.""" return register_accessor(name, pd.Index) def register_series_accessor(name: str) -> tp.Callable: """Decorator to register a custom `pd.Series` accessor.""" return register_accessor(name, pd.Series) def register_dataframe_accessor(name: str) -> tp.Callable: """Decorator to register a custom `pd.DataFrame` accessor.""" return register_accessor(name, pd.DataFrame) @register_index_accessor("vbt") class Vbt_IDXAccessor(DirNamesMixin, BaseIDXAccessor): """The main vectorbt accessor for `pd.Index`.""" def __init__(self, obj: tp.Index, **kwargs) -> None: self._obj = obj DirNamesMixin.__init__(self) BaseIDXAccessor.__init__(self, obj, **kwargs) idx_acc = Vbt_IDXAccessor """Shortcut for `Vbt_IDXAccessor`.""" __pdoc__["idx_acc"] = False class Vbt_Accessor(DirNamesMixin, GenericAccessor): """The main vectorbt accessor for `pd.Series` and `pd.DataFrame`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, **kwargs, ) -> None: DirNamesMixin.__init__(self) GenericAccessor.__init__(self, wrapper, obj=obj, **kwargs) pd_acc = Vbt_Accessor """Shortcut for `Vbt_Accessor`.""" __pdoc__["pd_acc"] = False @register_series_accessor("vbt") class Vbt_SRAccessor(DirNamesMixin, GenericSRAccessor): """The main vectorbt accessor for `pd.Series`.""" 
def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, **kwargs, ) -> None: DirNamesMixin.__init__(self) GenericSRAccessor.__init__(self, wrapper, obj=obj, **kwargs) sr_acc = Vbt_SRAccessor """Shortcut for `Vbt_SRAccessor`.""" __pdoc__["sr_acc"] = False @register_dataframe_accessor("vbt") class Vbt_DFAccessor(DirNamesMixin, GenericDFAccessor): """The main vectorbt accessor for `pd.DataFrame`.""" def __init__( self, wrapper: tp.Union[ArrayWrapper, tp.ArrayLike], obj: tp.Optional[tp.ArrayLike] = None, **kwargs, ) -> None: DirNamesMixin.__init__(self) GenericDFAccessor.__init__(self, wrapper, obj=obj, **kwargs) df_acc = Vbt_DFAccessor """Shortcut for `Vbt_DFAccessor`.""" __pdoc__["df_acc"] = False def register_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_Accessor) -> tp.Callable: """Decorator to register an accessor on top of a parent accessor.""" return register_accessor(name, parent) def register_idx_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_IDXAccessor) -> tp.Callable: """Decorator to register a `pd.Index` accessor on top of a parent accessor.""" return register_accessor(name, parent) def register_sr_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_SRAccessor) -> tp.Callable: """Decorator to register a `pd.Series` accessor on top of a parent accessor.""" return register_accessor(name, parent) def register_df_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_DFAccessor) -> tp.Callable: """Decorator to register a `pd.DataFrame` accessor on top of a parent accessor.""" return register_accessor(name, parent)